-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds lightning fabric integration (#749)
* wip lightning fabric logger * add fabric tests * lightning: fix test_lightning_val_updates_to_studio * lightning: move next_step logic to fabric * add sync method and use in lightning/fabric callbacks * fabric: auto-increment step * add fabric example notebook * update fabric notebook * fix mypy errors * skip fabric tests if not installed * fix(project): proper lighting extra name reference * fix(mypy): remove options that is default now * add back nox * fix linting issues * make torch import optional --------- Co-authored-by: Ivan Shcheklein <[email protected]>
- Loading branch information
1 parent
a7b7f55
commit b8a8ecf
Showing
8 changed files
with
595 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "QKSE19fW_Dnj" | ||
}, | ||
"source": [ | ||
"# DVCLive and Lightning Fabric" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "q-C_4R_o_QGG" | ||
}, | ||
"source": [ | ||
"## Install dvclive" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "-XFbvwq7TSwN", | ||
"outputId": "15d0e3b5-bb4a-4b3e-d37f-21608d1822ed" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install \"dvclive[lightning]\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "I6S6Uru1_Y0x" | ||
}, | ||
"source": [ | ||
"## Initialize DVC Repository" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "WcbvUl2uTV0y", | ||
"outputId": "aff9740c-26db-483d-ce30-cfef395f3cbb" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!git init -q\n", | ||
"!git config --local user.email \"[email protected]\"\n", | ||
"!git config --local user.name \"Your Name\"\n", | ||
"!dvc init -q\n", | ||
"!git commit -m \"DVC init\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "LmY4PLMh_cUk" | ||
}, | ||
"source": [ | ||
"## Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "85qErT5yTEbN" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import argparse\n", | ||
"from os import path\n", | ||
"from types import SimpleNamespace\n", | ||
"\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"import torchvision.transforms as T\n", | ||
"from lightning.fabric import Fabric, seed_everything\n", | ||
"from lightning.fabric.utilities.rank_zero import rank_zero_only\n", | ||
"from torch.optim.lr_scheduler import StepLR\n", | ||
"from torchmetrics.classification import Accuracy\n", | ||
"from torchvision.datasets import MNIST\n", | ||
"\n", | ||
"from dvclive.fabric import DVCLiveLogger\n", | ||
"\n", | ||
"DATASETS_PATH = (\"Datasets\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "UrmAHbhr_lgs" | ||
}, | ||
"source": [ | ||
"## Setup model code\n", | ||
"\n", | ||
"Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/image_classifier/train_fabric.py.\n", | ||
"\n", | ||
"Look for the `logger` statements where DVCLiveLogger calls were added." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "UCzTygUnTHM8" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class Net(nn.Module):\n", | ||
" def __init__(self) -> None:\n", | ||
" super().__init__()\n", | ||
" self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", | ||
" self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", | ||
" self.dropout1 = nn.Dropout(0.25)\n", | ||
" self.dropout2 = nn.Dropout(0.5)\n", | ||
" self.fc1 = nn.Linear(9216, 128)\n", | ||
" self.fc2 = nn.Linear(128, 10)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = self.conv1(x)\n", | ||
" x = F.relu(x)\n", | ||
" x = self.conv2(x)\n", | ||
" x = F.relu(x)\n", | ||
" x = F.max_pool2d(x, 2)\n", | ||
" x = self.dropout1(x)\n", | ||
" x = torch.flatten(x, 1)\n", | ||
" x = self.fc1(x)\n", | ||
" x = F.relu(x)\n", | ||
" x = self.dropout2(x)\n", | ||
" x = self.fc2(x)\n", | ||
" return F.log_softmax(x, dim=1)\n", | ||
"\n", | ||
"\n", | ||
"def run(hparams):\n", | ||
" # Create the DVCLive Logger\n", | ||
" logger = DVCLiveLogger(report=\"notebook\")\n", | ||
"\n", | ||
" # Log dict of hyperparameters\n", | ||
" logger.log_hyperparams(hparams.__dict__)\n", | ||
"\n", | ||
" # Create the Lightning Fabric object. The parameters like accelerator, strategy, devices etc. will be proided\n", | ||
" # by the command line. See all options: `lightning run model --help`\n", | ||
" fabric = Fabric()\n", | ||
"\n", | ||
" seed_everything(hparams.seed) # instead of torch.manual_seed(...)\n", | ||
"\n", | ||
" transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])\n", | ||
"\n", | ||
" # Let rank 0 download the data first, then everyone will load MNIST\n", | ||
" with fabric.rank_zero_first(local=False): # set `local=True` if your filesystem is not shared between machines\n", | ||
" train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)\n", | ||
" test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)\n", | ||
"\n", | ||
" train_loader = torch.utils.data.DataLoader(\n", | ||
" train_dataset,\n", | ||
" batch_size=hparams.batch_size,\n", | ||
" )\n", | ||
" test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)\n", | ||
"\n", | ||
" # don't forget to call `setup_dataloaders` to prepare for dataloaders for distributed training.\n", | ||
" train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)\n", | ||
"\n", | ||
" model = Net() # remove call to .to(device)\n", | ||
" optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)\n", | ||
"\n", | ||
" # don't forget to call `setup` to prepare for model / optimizer for distributed training.\n", | ||
" # the model is moved automatically to the right device.\n", | ||
" model, optimizer = fabric.setup(model, optimizer)\n", | ||
"\n", | ||
" scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)\n", | ||
"\n", | ||
" # use torchmetrics instead of manually computing the accuracy\n", | ||
" test_acc = Accuracy(task=\"multiclass\", num_classes=10).to(fabric.device)\n", | ||
"\n", | ||
" # EPOCH LOOP\n", | ||
" for epoch in range(1, hparams.epochs + 1):\n", | ||
" # TRAINING LOOP\n", | ||
" model.train()\n", | ||
" for batch_idx, (data, target) in enumerate(train_loader):\n", | ||
" # NOTE: no need to call `.to(device)` on the data, target\n", | ||
" optimizer.zero_grad()\n", | ||
" output = model(data)\n", | ||
" loss = F.nll_loss(output, target)\n", | ||
" fabric.backward(loss) # instead of loss.backward()\n", | ||
"\n", | ||
" optimizer.step()\n", | ||
" if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):\n", | ||
" print(\n", | ||
" \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n", | ||
" epoch,\n", | ||
" batch_idx * len(data),\n", | ||
" len(train_loader.dataset),\n", | ||
" 100.0 * batch_idx / len(train_loader),\n", | ||
" loss.item(),\n", | ||
" )\n", | ||
" )\n", | ||
"\n", | ||
" # Log dict of metrics\n", | ||
" logger.log_metrics({\"loss\": loss.item()})\n", | ||
"\n", | ||
" if hparams.dry_run:\n", | ||
" break\n", | ||
"\n", | ||
" scheduler.step()\n", | ||
"\n", | ||
" # TESTING LOOP\n", | ||
" model.eval()\n", | ||
" test_loss = 0\n", | ||
" with torch.no_grad():\n", | ||
" for data, target in test_loader:\n", | ||
" # NOTE: no need to call `.to(device)` on the data, target\n", | ||
" output = model(data)\n", | ||
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n", | ||
"\n", | ||
" # WITHOUT TorchMetrics\n", | ||
" # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", | ||
" # correct += pred.eq(target.view_as(pred)).sum().item()\n", | ||
"\n", | ||
" # WITH TorchMetrics\n", | ||
" test_acc(output, target)\n", | ||
"\n", | ||
" if hparams.dry_run:\n", | ||
" break\n", | ||
"\n", | ||
" # all_gather is used to aggregated the value across processes\n", | ||
" test_loss = fabric.all_gather(test_loss).sum() / len(test_loader.dataset)\n", | ||
"\n", | ||
" print(f\"\\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * test_acc.compute():.0f}%)\\n\")\n", | ||
"\n", | ||
" # log additional metrics\n", | ||
" logger.log_metrics({\"test_loss\": test_loss, \"test_acc\": 100 * test_acc.compute()})\n", | ||
"\n", | ||
" test_acc.reset()\n", | ||
"\n", | ||
" if hparams.dry_run:\n", | ||
" break\n", | ||
"\n", | ||
" # When using distributed training, use `fabric.save`\n", | ||
" # to ensure the current process is allowed to save a checkpoint\n", | ||
" if hparams.save_model:\n", | ||
" fabric.save(\"mnist_cnn.pt\", model.state_dict())\n", | ||
"\n", | ||
" # `logger.experiment` provides access to the `dvclive.Live` instance where you can use additional logging methods.\n", | ||
" # Check that `rank_zero_only.rank == 0` to avoid logging in other processes.\n", | ||
" if rank_zero_only.rank == 0:\n", | ||
" logger.experiment.log_artifact(\"mnist_cnn.pt\")\n", | ||
"\n", | ||
" # Call finalize to save final results as a DVC experiment\n", | ||
" logger.finalize(\"success\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "o5_v9lRDAM7l" | ||
}, | ||
"source": [ | ||
"## Train the model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 1000 | ||
}, | ||
"id": "BbCXen1PTM4V", | ||
"outputId": "b79c90eb-74cc-474d-c0dd-21245064bca8" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"hparams = SimpleNamespace(batch_size=64, epochs=5, lr=1.0, gamma=0.7, dry_run=False, seed=1, log_interval=10, save_model=True)\n", | ||
"run(hparams)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "DnqCrlbLAopV" | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.