Commit 5fcdc4c: linter

shcheklein committed Dec 20, 2024
1 parent 67e6c04
Showing 3 changed files with 36 additions and 30 deletions.
44 changes: 26 additions & 18 deletions examples/DVCLive-Fabric.ipynb
@@ -30,7 +30,7 @@
},
"outputs": [],
"source": [
"!pip install \"dvclive[lightning]\""
"%pip install \"dvclive[lightning]\""
]
},
{
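
A quick note on the change above: the `%pip` magic installs into the environment of the kernel that is executing the notebook, while `!pip` runs in a subshell and may target a different interpreter. In a notebook cell:

    !pip install "dvclive[lightning]"   # subshell; may use another Python
    %pip install "dvclive[lightning]"   # installs into the active kernel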
@@ -82,9 +82,9 @@
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torch.nn.functional as F # noqa: N812\n",
"from torch import optim\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms as T # noqa: N812\n",
"from lightning.fabric import Fabric, seed_everything\n",
"from lightning.fabric.utilities.rank_zero import rank_zero_only\n",
"from torch.optim.lr_scheduler import StepLR\n",
@@ -149,8 +149,9 @@
" # Log dict of hyperparameters\n",
" logger.log_hyperparams(hparams.__dict__)\n",
"\n",
" # Create the Lightning Fabric object. The parameters like accelerator, strategy, devices etc. will be proided\n",
" # by the command line. See all options: `lightning run model --help`\n",
" # Create the Lightning Fabric object. The parameters like accelerator, strategy,\n",
" # devices etc. will be proided by the command line. See all options: `lightning\n",
" # run model --help`\n",
" fabric = Fabric()\n",
"\n",
" seed_everything(hparams.seed) # instead of torch.manual_seed(...)\n",
@@ -182,14 +183,15 @@
" test_dataset, batch_size=hparams.batch_size\n",
" )\n",
"\n",
" # don't forget to call `setup_dataloaders` to prepare for dataloaders for distributed training.\n",
" # don't forget to call `setup_dataloaders` to prepare for dataloaders for\n",
" # distributed training.\n",
" train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)\n",
"\n",
" model = Net() # remove call to .to(device)\n",
" optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)\n",
"\n",
" # don't forget to call `setup` to prepare for model / optimizer for distributed training.\n",
" # the model is moved automatically to the right device.\n",
" # don't forget to call `setup` to prepare for model / optimizer for\n",
" # distributed training. The model is moved automatically to the right device.\n",
" model, optimizer = fabric.setup(model, optimizer)\n",
"\n",
" scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)\n",
@@ -210,8 +212,10 @@
"\n",
" optimizer.step()\n",
" if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):\n",
" print(\n",
" f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" done = (batch_idx * len(data)) / len(train_loader.dataset)\n",
" pct = 100.0 * batch_idx / len(train_loader)\n",
" print( # noqa: T201\n",
" f\"-> Epoch: {epoch} [{done} ({pct:.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" )\n",
"\n",
" # Log dict of metrics\n",
@@ -232,8 +236,8 @@
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
"\n",
" # WITHOUT TorchMetrics\n",
" # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
" # correct += pred.eq(target.view_as(pred)).sum().item()\n",
" # pred = output.argmax(dim=1, keepdim=True) # get the index of the max\n",
" # log-probability correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" # WITH TorchMetrics\n",
" test_acc(output, target)\n",
@@ -243,9 +247,10 @@
"\n",
" # all_gather is used to aggregated the value across processes\n",
" test_loss = fabric.all_gather(test_loss).sum() / len(test_loader.dataset)\n",
" acc = 100 * test_acc.compute()\n",
"\n",
" print(\n",
" f\"\\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * test_acc.compute():.0f}%)\\n\"\n",
" print( # noqa: T201\n",
" f\"\\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({acc:.0f}%)\\n\"\n",
" )\n",
"\n",
" # log additional metrics\n",
@@ -263,8 +268,9 @@
" if hparams.save_model:\n",
" fabric.save(\"mnist_cnn.pt\", model.state_dict())\n",
"\n",
" # `logger.experiment` provides access to the `dvclive.Live` instance where you can use additional logging methods.\n",
" # Check that `rank_zero_only.rank == 0` to avoid logging in other processes.\n",
" # `logger.experiment` provides access to the `dvclive.Live` instance where you\n",
" # can use additional logging methods. Check that `rank_zero_only.rank == 0` to\n",
" # avoid logging in other processes.\n",
" if rank_zero_only.rank == 0:\n",
" logger.experiment.log_artifact(\"mnist_cnn.pt\")\n",
"\n",
@@ -322,11 +328,13 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.12.2"
}
},
"nbformat": 4,
11 changes: 5 additions & 6 deletions examples/DVCLive-PyTorch-Lightning.ipynb
@@ -35,7 +35,7 @@
},
"outputs": [],
"source": [
"!pip install \"dvclive[lightning]\""
"%pip install \"dvclive[lightning]\""
]
},
{
@@ -75,7 +75,7 @@
"\n",
"\n",
"class LitAutoEncoder(pl.LightningModule):\n",
" def __init__(self, encoder_size=64, lr=1e-3):\n",
" def __init__(self, encoder_size=64, lr=1e-3): # noqa: ARG002\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.encoder = torch.nn.Sequential(\n",
@@ -89,7 +89,7 @@
" torch.nn.Linear(encoder_size, 28 * 28),\n",
" )\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" def training_step(self, batch, batch_idx): # noqa: ARG002\n",
" x, y = batch\n",
" x = x.view(x.size(0), -1)\n",
" z = self.encoder(x)\n",
@@ -98,7 +98,7 @@
" self.log(\"train_mse\", train_mse)\n",
" return train_mse\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" def validation_step(self, batch, batch_idx): # noqa: ARG002\n",
" x, y = batch\n",
" x = x.view(x.size(0), -1)\n",
" z = self.encoder(x)\n",
@@ -108,8 +108,7 @@
" return val_mse\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)\n",
" return optimizer"
" return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)"
]
},
{
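
On the suppressions in this file: `ARG002` flags unused method arguments. Lightning's hook signatures require `batch_idx` even when the body never reads it, and `lr` in `__init__` is consumed via `save_hyperparameters()` rather than directly, so inline suppression is the pragmatic fix:

    def training_step(self, batch, batch_idx):  # noqa: ARG002  (signature fixed by the hook API)
        ...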
11 changes: 5 additions & 6 deletions examples/DVCLive-Quickstart.ipynb
@@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install dvclive"
"%pip install dvclive"
]
},
{
@@ -97,8 +97,7 @@
" \"\"\"Get model prediction scores.\"\"\"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" y_pred = model(x)\n",
" return y_pred\n",
" return model(x)\n",
"\n",
"\n",
"def get_metrics(y, y_pred, y_pred_label):\n",
@@ -156,14 +155,14 @@
" out_matrix[xs:xe, ys:ye, c] = (1 - image) * label_color[c]\n",
" out_matrix[ys:ye, xs:xe, c] = (1 - image) * label_color[c]\n",
"\n",
" for i, j in confusion:\n",
" for i, j in confusion: # noqa: PLC0206\n",
" image = confusion[(i, j)]\n",
" assert image.shape == image_shape\n",
" assert image.shape == image_shape # noqa: S101\n",
" xs = (i + 1) * frame_size + 1\n",
" xe = (i + 2) * frame_size - 1\n",
" ys = (j + 1) * frame_size + 1\n",
" ye = (j + 2) * frame_size - 1\n",
" assert (xe - xs, ye - ys) == image_shape\n",
" assert (xe - xs, ye - ys) == image_shape # noqa: S101\n",
" if i != j:\n",
" for c in range(3):\n",
" out_matrix[xs:xe, ys:ye, c] = (1 - image) * incorrect_color[c]\n",
