Skip to content

Commit

Permalink
Add a code block to visualize curves stored in pretrained weight
Browse files Browse the repository at this point in the history
  • Loading branch information
HangJung97 committed Oct 12, 2023
1 parent 329e15e commit 8febf84
Showing 1 changed file with 93 additions and 18 deletions.
111 changes: 93 additions & 18 deletions notebooks/camus_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@
"source": [
"# II. Dataset <a class=\"anchor\" id=\"dataset\"></a>\n",
"\n",
"Once the environment is successfully setup, download the CAMUS dataset by executing the following cell. The dataset will be downloaded to the `data/` folder. \n",
"> **&#9432;** Even if you have downloaded the dataset before, please execute the following cell to import the required packages."
"Once the environment is successfully setup, download the CAMUS dataset by executing the following cell. The dataset will be downloaded to the `data/` folder. "
]
},
{
Expand Down Expand Up @@ -111,6 +110,8 @@
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from src.utils.file_and_folder_operations import subdirs\n",
Expand Down Expand Up @@ -358,7 +359,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Definition of optimizer and loss function"
"### Definition of optimizer and loss function\n",
"We will use the Adam optimizer. The loss function is the combination of the Dice and cross-entropy (CE) loss, which is a standard loss function for segmentation tasks.\n",
"> **&#9432;** The exact computation of the loss function is CE - Dice, which means its minimum value is -1 instead of 0 in the best case scenario."
]
},
{
Expand Down Expand Up @@ -403,6 +406,7 @@
"from monai.data import DataLoader\n",
"from torch import nn\n",
"from torch.utils.data import Dataset\n",
"from tqdm.auto import tqdm\n",
"\n",
"from src.utils.tensor_utils import sum_tensor\n",
"\n",
Expand Down Expand Up @@ -613,6 +617,7 @@
" {\n",
" \"max_epochs\": max_epochs,\n",
" \"current_epoch\": epoch + 1,\n",
" \"best_metric_epoch\": best_metric_epoch,\n",
" \"train_loss\": epoch_train_loss_values,\n",
" \"val_loss\": epoch_val_loss_values,\n",
" \"epoch_val\": epoch_val,\n",
Expand All @@ -629,6 +634,7 @@
" {\n",
" \"max_epochs\": max_epochs,\n",
" \"current_epoch\": epoch + 1,\n",
" \"best_metric_epoch\": best_metric_epoch,\n",
" \"train_loss\": epoch_train_loss_values,\n",
" \"val_loss\": epoch_val_loss_values,\n",
" \"epoch_val\": epoch_val,\n",
Expand Down Expand Up @@ -751,7 +757,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the best model"
"### Download the pre-trained model weight for 50 epochs and visualize learning curves and metrics over 50 epochs"
]
},
{
Expand All @@ -760,9 +766,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Set to True to download and use pretrained weights\n",
"# Set to False to use the best model you trained just before\n",
"use_pretrained_weights = False\n",
"from matplotlib.ticker import MaxNLocator\n",
"\n",
"# Path to the logging directory\n",
"log_dir = \"../logs/camus_segmentation\"\n",
Expand All @@ -773,21 +777,92 @@
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"# Load the best model\n",
"if use_pretrained_weights:\n",
" best_model = Path(log_dir) / \"best_model_50epochs.pth\"\n",
" if not best_model.is_file():\n",
" import urllib.request\n",
"# Load the best model weight\n",
"best_model = Path(log_dir) / \"best_model_50epochs.pth\"\n",
"if not best_model.is_file():\n",
" import urllib.request\n",
"\n",
" urllib.request.urlretrieve(\n",
" \"https://www.creatis.insa-lyon.fr/~bernard/camus/best_model_50epochs.pth\",\n",
" str(best_model),\n",
" )\n",
" urllib.request.urlretrieve(\n",
" \"https://www.creatis.insa-lyon.fr/~bernard/camus/best_model_50epochs.pth\",\n",
" str(best_model),\n",
" )\n",
"\n",
"# Load the weight\n",
"weight = torch.load(best_model)\n",
"epoch_train_loss_values = weight[\"train_loss\"]\n",
"epoch_val_loss_values = weight[\"val_loss\"]\n",
"epoch_val = weight[\"epoch_val\"]\n",
"metric_values = weight[\"metric_values\"]\n",
"metric_per_class = weight[\"metric_per_class\"]\n",
"max_epochs = weight[\"max_epochs\"]\n",
"best_metric_epoch = weight[\"best_metric_epoch\"]\n",
"\n",
"# Plot the training losses, validation losses and validation dice over epochs\n",
"trains_epoch = list(range(1, max_epochs + 1, 1))\n",
"vals_epochs = epoch_val\n",
"\n",
"plt.figure(\"train\", (16, 8))\n",
"ax = plt.subplot(1, 3, 1)\n",
"plt.title(\"Epoch Average Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n",
"plt.plot(trains_epoch, epoch_train_loss_values, color=\"red\", label=\"Train\")\n",
"plt.plot(vals_epochs, epoch_val_loss_values, color=\"blue\", label=\"Val\")\n",
"# Add a vertical line at the best model epoch\n",
"plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n",
"plt.text(\n",
" best_metric_epoch + 0.5,\n",
" min(epoch_train_loss_values) * 4 / 5,\n",
" f\"Best model epoch = {best_metric_epoch}\",\n",
" rotation=90,\n",
")\n",
"plt.legend(loc=\"upper right\")\n",
"ax = plt.subplot(1, 3, 2)\n",
"plt.title(\"Val Mean Dice\")\n",
"plt.xlabel(\"Epoch\")\n",
"ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n",
"plt.plot(vals_epochs, metric_values, color=\"green\")\n",
"plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n",
"\n",
"ax = plt.subplot(1, 3, 3)\n",
"plt.title(\"Val Mean Dice per Class\")\n",
"plt.xlabel(\"Epoch\")\n",
"ax.xaxis.set_major_locator(MaxNLocator(integer=True)) # Set integer ticks for x-axis\n",
"legend_metric = [\"left ventricle\", \"myocardium\", \"left atrium\"]\n",
"for i in range(1, num_classes, 1):\n",
" plt.plot(vals_epochs, metric_per_class[f\"metric/{i}\"], label=legend_metric[i - 1])\n",
"plt.axvline(best_metric_epoch, color=\"gray\", linestyle=\"--\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the pre-trained model weight into the U-Net"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set to True to use the pretrained model\n",
"# Set to False to use the model you just trained\n",
"use_pretrained = True\n",
"\n",
"if use_pretrained:\n",
" weight = torch.load(Path(log_dir) / \"best_model_50epochs.pth\")\n",
"else:\n",
" best_model = Path(log_dir) / \"best_metric_model.pth\"\n",
" weight = torch.load(Path(log_dir) / \"best_metric_model.pth\")\n",
"\n",
"unet.load_state_dict(torch.load(best_model)[\"model_state_dict\"])\n",
"# Load the weight into the U-Net\n",
"unet.load_state_dict(weight[\"model_state_dict\"])\n",
"# Move the U-Net to the correct device\n",
"unet.to(device)\n",
"# Put the U-Net in evaluation mode\n",
"unet.eval()"
]
},
Expand Down

0 comments on commit 8febf84

Please sign in to comment.