From 8febf84697871007556aef8438dc8162bc9286d4 Mon Sep 17 00:00:00 2001 From: Hang Jung Ling Date: Thu, 12 Oct 2023 10:25:24 +0200 Subject: [PATCH] Add a code block to visualize curves stored in pretrained weight --- notebooks/camus_segmentation.ipynb | 111 ++++++++++++++++++++++++----- 1 file changed, 93 insertions(+), 18 deletions(-) diff --git a/notebooks/camus_segmentation.ipynb b/notebooks/camus_segmentation.ipynb index 0519ad1..eeacc39 100644 --- a/notebooks/camus_segmentation.ipynb +++ b/notebooks/camus_segmentation.ipynb @@ -66,8 +66,7 @@ "source": [ "# II. Dataset \n", "\n", - "Once the environment is successfully setup, download the CAMUS dataset by executing the following cell. The dataset will be downloaded to the `data/` folder. \n", - "> **ⓘ** Even if you have downloaded the dataset before, please execute the following cell to import the required packages." + "Once the environment is successfully setup, download the CAMUS dataset by executing the following cell. The dataset will be downloaded to the `data/` folder. " ] }, { @@ -111,6 +110,8 @@ "metadata": {}, "outputs": [], "source": [ + "from pathlib import Path\n", + "\n", "from sklearn.model_selection import train_test_split\n", "\n", "from src.utils.file_and_folder_operations import subdirs\n", @@ -358,7 +359,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Definition of optimizer and loss function" + "### Definition of optimizer and loss function\n", + "We will use the Adam optimizer. The loss function is the combination of the Dice and cross-entropy (CE) loss, which is a standard loss function for segmentation tasks.\n", + "> **ⓘ** The exact computation of the loss function is CE - Dice, which means its minimum value is -1 instead of 0 in the best case scenario." ] }, { @@ -403,6 +406,7 @@ "from monai.data import DataLoader\n", "from torch import nn\n", "from torch.utils.data import Dataset\n", + "from tqdm.auto import tqdm\n", "\n", "from src.utils.tensor_utils import sum_tensor\n", "\n", @@ -613,6 +617,7 @@ " {\n", " \"max_epochs\": max_epochs,\n", " \"current_epoch\": epoch + 1,\n", + " \"best_metric_epoch\": best_metric_epoch,\n", " \"train_loss\": epoch_train_loss_values,\n", " \"val_loss\": epoch_val_loss_values,\n", " \"epoch_val\": epoch_val,\n", @@ -629,6 +634,7 @@ " {\n", " \"max_epochs\": max_epochs,\n", " \"current_epoch\": epoch + 1,\n", + " \"best_metric_epoch\": best_metric_epoch,\n", " \"train_loss\": epoch_train_loss_values,\n", " \"val_loss\": epoch_val_loss_values,\n", " \"epoch_val\": epoch_val,\n", @@ -751,7 +757,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Load the best model" + "### Download the pre-trained model weight for 50 epochs and visualize learning curves and metrics over 50 epochs" ] }, { @@ -760,9 +766,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Set to True to download and use pretrained weights\n", - "# Set to False to use the best model you trained just before\n", - "use_pretrained_weights = False\n", + "from matplotlib.ticker import MaxNLocator\n", "\n", "# Path to the logging directory\n", "log_dir = \"../logs/camus_segmentation\"\n", @@ -773,21 +777,92 @@ "else:\n", " device = torch.device(\"cpu\")\n", "\n", - "# Load the best model\n", - "if use_pretrained_weights:\n", - " best_model = Path(log_dir) / \"best_model_50epochs.pth\"\n", - " if not best_model.is_file():\n", - " import urllib.request\n", + "# Load the best model weight\n", + "best_model = Path(log_dir) / \"best_model_50epochs.pth\"\n", + "if not best_model.is_file():\n", + " import urllib.request\n", "\n", - " urllib.request.urlretrieve(\n", - " \"https://www.creatis.insa-lyon.fr/~bernard/camus/best_model_50epochs.pth\",\n", - " str(best_model),\n", - " )\n", + " urllib.request.urlretrieve(\n", + " \"https://www.creatis.insa-lyon.fr/~bernard/camus/best_model_50epochs.pth\",\n", + " str(best_model),\n", + " )\n", + "\n", + "# Load the weight\n", + "weight = torch.load(best_model)\n", + "epoch_train_loss_values = weight[\"train_loss\"]\n", + "epoch_val_loss_values = weight[\"val_loss\"]\n", + "epoch_val = weight[\"epoch_val\"]\n", + "metric_values = weight[\"metric_values\"]\n", + "metric_per_class = weight[\"metric_per_class\"]\n", + "max_epochs = weight[\"max_epochs\"]\n", + "best_metric_epoch = weight[\"best_metric_epoch\"]\n", + "\n", + "# Plot the training losses, validation losses and validation dice over epochs\n", + "trains_epoch = list(range(1, max_epochs + 1, 1))\n", + "vals_epochs = epoch_val\n", + "\n", + "plt.figure(\"train\", (16, 8))\n", + "ax = plt.subplot(1, 3, 1)\n", + "plt.title(\"Epoch Average Loss\")\n", + "plt.xlabel(\"Epoch\")\n", + "ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n", + "plt.plot(trains_epoch, epoch_train_loss_values, color=\"red\", label=\"Train\")\n", + "plt.plot(vals_epochs, epoch_val_loss_values, color=\"blue\", label=\"Val\")\n", + "# Add a vertical line at the best model epoch\n", + "plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n", + "plt.text(\n", + " best_metric_epoch + 0.5,\n", + " min(epoch_train_loss_values) * 4 / 5,\n", + " f\"Best model epoch = {best_metric_epoch}\",\n", + " rotation=90,\n", + ")\n", + "plt.legend(loc=\"upper right\")\n", + "ax = plt.subplot(1, 3, 2)\n", + "plt.title(\"Val Mean Dice\")\n", + "plt.xlabel(\"Epoch\")\n", + "ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n", + "plt.plot(vals_epochs, metric_values, color=\"green\")\n", + "plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n", + "\n", + "ax = plt.subplot(1, 3, 3)\n", + "plt.title(\"Val Mean Dice per Class\")\n", + "plt.xlabel(\"Epoch\")\n", + "ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n", + "legend_metric = [\"left ventricle\", \"myocardium\", \"left atrium\"]\n", + "for i in range(1, num_classes, 1):\n", + " plt.plot(vals_epochs, metric_per_class[f\"metric/{i}\"], label=legend_metric[i - 1])\n", + "plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n", + "plt.legend(loc=\"upper left\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the pre-trained model weight into the U-Net" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set to True to use the pretrained model\n", + "# Set to False to use the model you just trained\n", + "use_pretrained = True\n", + "\n", + "if use_pretrained:\n", + " weight = torch.load(Path(log_dir) / \"best_model_50epochs.pth\")\n", "else:\n", - " best_model = Path(log_dir) / \"best_metric_model.pth\"\n", + " weight = torch.load(Path(log_dir) / \"best_metric_model.pth\")\n", "\n", - "unet.load_state_dict(torch.load(best_model)[\"model_state_dict\"])\n", + "# Load the weight into the U-Net\n", + "unet.load_state_dict(weight[\"model_state_dict\"])\n", + "# Move the U-Net to the correct device\n", "unet.to(device)\n", + "# Put the U-Net in evaluation mode\n", "unet.eval()" ] },