diff --git a/pytorch-lightning_ipynb/template/template_classification_basic.ipynb b/pytorch-lightning_ipynb/template/template_classification_basic.ipynb
deleted file mode 100644
index f10012f..0000000
--- a/pytorch-lightning_ipynb/template/template_classification_basic.ipynb
+++ /dev/null
@@ -1,494 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "77bd01a5",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext watermark\n",
- "%watermark -p torch,pytorch_lightning,torchmetrics,matplotlib"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a4c7e37a",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext pycodestyle_magic\n",
- "%flake8_on --ignore W291,W293,E703"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1d1a520c",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "# Model Zoo -- DESCRIPTION"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8323564f",
- "metadata": {},
- "source": [
- "- DESCRIPTION\n",
- "\n",
- "\n",
- "### References\n",
- "\n",
- "- ???"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4ec5cdbb",
- "metadata": {},
- "source": [
- "## General settings and hyperparameters"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b9b0a861",
- "metadata": {},
- "source": [
- "- Here, we specify some general hyperparameter values and general settings\n",
- "- Note that for small datatsets, it is not necessary and better not to use multiple workers as it can sometimes cause issues with too many open files in PyTorch. So, if you have problems with the data loader later, try setting `NUM_WORKERS = 0` instead."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fb1d2d02",
- "metadata": {},
- "outputs": [],
- "source": [
- "BATCH_SIZE = 256\n",
- "NUM_EPOCHS = 10\n",
- "LEARNING_RATE = 0.005\n",
- "NUM_WORKERS = 4"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "24b6f86d",
- "metadata": {},
- "source": [
- "## Implementing a Neural Network using PyTorch Lightning's `LightningModule`"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "36c91549",
- "metadata": {},
- "source": [
- "- In this section, we set up the main model architecture using the `LightningModule` from PyTorch Lightning.\n",
- "- We start with defining our neural network model in pure PyTorch, and then we use it in the `LightningModule` to get all the extra benefits that PyTorch Lightning provides."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4eacca3d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# UNIQUE MODEL CODE"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "59d6b481",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/lightningmodule_classifier_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "24983cbb",
- "metadata": {},
- "source": [
- "## Setting up the dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b6198374",
- "metadata": {},
- "source": [
- "- In this section, we are going to set up our dataset."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "951705b9",
- "metadata": {},
- "source": [
- "### Inspecting the dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "17958c36",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_dataset/dataset_???_check.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ee5b98a9",
- "metadata": {},
- "source": [
- "### Performance baseline"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f366d74e",
- "metadata": {},
- "source": [
- "- Especially for imbalanced datasets, it's quite useful to compute a performance baseline.\n",
- "- In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- you want your model to be better than that!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1d8ed6d1",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_dataset/performance_baseline.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0bdd53b7",
- "metadata": {},
- "source": [
- "## A quick visual check"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "084cffe3",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load plot_visual-check_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4a8b3c3b",
- "metadata": {},
- "source": [
- "### Setting up a `DataModule`"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "cbf59787",
- "metadata": {},
- "source": [
- "- There are three main ways we can prepare the dataset for Lightning. We can\n",
- " 1. make the dataset part of the model;\n",
- " 2. set up the data loaders as usual and feed them to the fit method of a Lightning Trainer -- the Trainer is introduced in the next subsection;\n",
- " 3. create a LightningDataModule.\n",
- "- Here, we are going to use approach 3, which is the most organized approach. The `LightningDataModule` consists of several self-explanatory methods as we can see below:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "86d43c10",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/datamodule_???_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ce6be803",
- "metadata": {},
- "source": [
- "- Note that the `prepare_data` method is usually used for steps that only need to be executed once, for example, downloading the dataset; the `setup` method defines the the dataset loading -- if you run your code in a distributed setting, this will be called on each node / GPU. \n",
- "- Next, lets initialize the `DataModule`; we use a random seed for reproducibility (so that the data set is shuffled the same way when we re-execute this code):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "78f7b3a6",
- "metadata": {},
- "outputs": [],
- "source": [
- "torch.manual_seed(1) \n",
- "data_module = DataModule(data_path='./data')"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "59834e86",
- "metadata": {},
- "source": [
- "## Training the model using the PyTorch Lightning Trainer class"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "bee09340",
- "metadata": {},
- "source": [
- "- Next, we initialize our model.\n",
- "- Also, we define a call back so that we can obtain the model with the best validation set performance after training.\n",
- "- PyTorch Lightning offers [many advanced logging services](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) like Weights & Biases. Here, we will keep things simple and use the `CSVLogger`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "970ae42c",
- "metadata": {},
- "outputs": [],
- "source": [
- "pytorch_model = PyTorchModel(\n",
- " ???\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4e3238a8",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/logger_csv_acc_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5861ddcb",
- "metadata": {},
- "source": [
- "- Now it's time to train our model:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ccf14578",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/trainer_nb_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9c1ea45d",
- "metadata": {},
- "source": [
- "## Evaluating the model"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9358c717",
- "metadata": {},
- "source": [
- "- After training, let's plot our training ACC and validation ACC using pandas, which, in turn, uses matplotlib for plotting (you may want to consider a [more advanced logger](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) that does that for you):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "842ed5b5",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/logger_csv_plot_basic.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6e426862",
- "metadata": {},
- "source": [
- "- The `trainer` automatically saves the model with the best validation accuracy automatically for us, we which we can load from the checkpoint via the `ckpt_path='best'` argument; below we use the `trainer` instance to evaluate the best model on the test set:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "27ddf180",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0b88f525",
- "metadata": {},
- "source": [
- "## Predicting labels of new data"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "be642c73",
- "metadata": {},
- "source": [
- "- You can use the `trainer.predict` method on a new `DataLoader` or `DataModule` to apply the model to new data.\n",
- "- Alternatively, you can also manually load the best model from a checkpoint as shown below:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5367ee40",
- "metadata": {},
- "outputs": [],
- "source": [
- "path = trainer.checkpoint_callback.best_model_path\n",
- "print(path)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b3dad29f",
- "metadata": {},
- "outputs": [],
- "source": [
- "lightning_model = LightningModel.load_from_checkpoint(path, model=pytorch_model)\n",
- "lightning_model.eval();"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "92211016",
- "metadata": {},
- "source": [
- "- Note that our PyTorch model, which is passed to the Lightning model requires input arguments. However, this is automatically being taken care of since we used `self.save_hyperparameters()` in our PyTorch model's `__init__` method.\n",
- "- Now, below is an example applying the model manually. Here, pretend that the `test_dataloader` is a new data loader."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "221df0f1",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/datamodule_testloader.py"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "99e4f69a",
- "metadata": {},
- "source": [
- "Just as an internal check, if the model was loaded correctly, the test accuracy below should be identical to the test accuracy we saw earlier in the previous section."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b265a5b1",
- "metadata": {},
- "outputs": [],
- "source": [
- "test_acc = acc.compute()\n",
- "print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "68c3e00d",
- "metadata": {},
- "source": [
- "## Inspecting Failure Cases"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5cc44b41",
- "metadata": {},
- "source": [
- "- In practice, it is often informative to look at failure cases like wrong predictions for particular training instances as it can give us some insights into the model behavior and dataset.\n",
- "- Inspecting failure cases can sometimes reveal interesting patterns and even highlight dataset and labeling issues."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "12b82d48",
- "metadata": {},
- "outputs": [],
- "source": [
- "# In the case of ???, the class label mapping\n",
- "# ???\n",
- "class_dict = {???}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e298ac1b",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/plot_failurecases_basic.py"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dbb4ab75",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load ../code_lightningmodule/plot_confusion-matrix_basic.py"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/templates/pytorch_lightning/tune_classification_basic.py b/templates/pytorch_lightning/tune_classification_basic.py
index a0ab26f..a80fcc2 100644
--- a/templates/pytorch_lightning/tune_classification_basic.py
+++ b/templates/pytorch_lightning/tune_classification_basic.py
@@ -1,5 +1,8 @@
-import time
import argparse
+import time
+import subprocess
+import sys
+
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -13,6 +16,12 @@
from torch.utils.data.dataset import random_split
+def install(package):
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+
+install("torchmetrics")
+
# Argparse helper
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)