diff --git a/Colab_notebooks/U-Net_2D_Multilabel_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_2D_Multilabel_ZeroCostDL4Mic.ipynb
index 9546e9a..33f7e44 100644
--- a/Colab_notebooks/U-Net_2D_Multilabel_ZeroCostDL4Mic.ipynb
+++ b/Colab_notebooks/U-Net_2D_Multilabel_ZeroCostDL4Mic.ipynb
@@ -157,7 +157,7 @@
"source": [
"#@markdown ##Play to install U-Net dependencies\n",
"# Install packages which are not included in Google Colab\n",
- "!pip install data\n",
+ "!pip install -q data\n",
"!pip install -q tifffile # contains tools to operate tiff-files\n",
"!pip install -q wget\n",
"!pip install -q fpdf2\n",
@@ -165,6 +165,7 @@
"!pip install -q zarr\n",
"!pip install -q imagecodecs\n",
"!pip install -q bioimageio.core==0.6.9\n",
+ "!pip install -q tf-keras==2.15\n",
"!pip install -q tensorflow==2.15"
]
},
@@ -209,7 +210,6 @@
"Notebook_version = '2.1.3'\n",
"Network = 'U-Net (2D) multilabel'\n",
"\n",
- "\n",
"from builtins import any as b_any\n",
"\n",
"def get_requirements_path():\n",
@@ -692,7 +692,7 @@
" model.summary()\n",
"\n",
" if(pretrained_weights):\n",
- " \tmodel.load_weights(pretrained_weights);\n",
+ " model.load_weights(pretrained_weights)\n",
"\n",
" return model\n",
"\n",
@@ -803,7 +803,7 @@
" \"\"\"\n",
" if author_input_info.strip() == '':\n",
" return None\n",
- " \n",
+ "\n",
" auth_order = ['name', 'affiliation', 'email', 'orcid', 'github_user']\n",
" auth_dict = {}\n",
"\n",
@@ -815,7 +815,7 @@
" else:\n",
" auth_dict[auth_order[i]] = auth_info_split[i].strip()\n",
"\n",
- " return bioimageio_spec.Author(**auth_dict) \n",
+ " return bioimageio_spec.Author(**auth_dict)\n",
"\n",
"def make_maintainer(maintainer_input_info: str):\n",
" \"\"\"\n",
@@ -838,7 +838,7 @@
" else:\n",
" maint_dict[maint_order[i]] = maint_info_split[i].strip()\n",
"\n",
- " return bioimageio_spec.Maintainer(**maint_dict) \n",
+ " return bioimageio_spec.Maintainer(**maint_dict)\n",
"\n",
"\n",
"# -------------- Other definitions -----------\n",
@@ -961,7 +961,7 @@
" if Use_Default_Advanced_Parameters:\n",
" pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n",
" pdf.cell(200, 5, txt='The following parameters were used for training:')\n",
- " pdf.ln(1)\n",
+ " pdf.ln(3)\n",
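+    "    # Build a small HTML table of the training parameters; fpdf2's write_html renders it into the report PDF below\n",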
" html = \"\"\"\n",
"
\n",
" \n",
@@ -969,36 +969,36 @@
" Value | \n",
"
\n",
" \n",
- " number_of_epochs | \n",
- " {0} | \n",
+ " number_of_epochs | \n",
+ " {0} | \n",
"
\n",
" \n",
- " patch_size | \n",
- " {1} | \n",
+ " patch_size | \n",
+ " {1} | \n",
"
\n",
" \n",
- " batch_size | \n",
- " {2} | \n",
+ " batch_size | \n",
+ " {2} | \n",
"
\n",
" \n",
- " number_of_steps | \n",
- " {3} | \n",
+ " number_of_steps | \n",
+ " {3} | \n",
"
\n",
" \n",
- " percentage_validation | \n",
- " {4} | \n",
+ " percentage_validation | \n",
+ " {4} | \n",
"
\n",
" \n",
- " initial_learning_rate | \n",
- " {5} | \n",
+ " initial_learning_rate | \n",
+ " {5} | \n",
"
\n",
" \n",
- " pooling_steps | \n",
- " {6} | \n",
+ " pooling_steps | \n",
+ " {6} | \n",
"
\n",
" \n",
- " min_fraction | \n",
- " {7} | \n",
+ " min_fraction | \n",
+ " {7} | \n",
"
\n",
" \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n",
" pdf.write_html(html)\n",
@@ -1204,21 +1204,10 @@
"source": [
"#@markdown ##Run this cell to check if you have GPU access\n",
"\n",
- "if tf.test.gpu_device_name()=='':\n",
- " print('You do not have GPU access.')\n",
- " print('Did you change your runtime ?')\n",
- " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n",
- " print('Expect slow performance. To access GPU try reconnecting later')\n",
- "\n",
- "else:\n",
- " print('You have GPU access')\n",
- " !nvidia-smi\n",
- "\n",
- "# from tensorflow.python.client import device_lib\n",
- "# device_lib.list_local_devices()\n",
- "\n",
- "# print the tensorflow version\n",
- "print('Tensorflow version is ' + str(tf.__version__))\n"
+ "!if type nvidia-smi >/dev/null 2>&1; then \\\n",
+ " echo \"You have GPU access\"; nvidia-smi; \\\n",
+ " else \\\n",
+ " echo -e \"You do not have GPU access.\\nDid you change your runtime?\\nIf the runtime setting is correct then Google did not allocate a GPU for your session\\nExpect slow performance. To access GPU try reconnecting later\"; fi"
]
},
{
@@ -1231,7 +1220,7 @@
"---\n",
" To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n",
"\n",
- " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive.\n",
+    " Play the cell below to mount your Google Drive. Click on **Connect to Google Drive** and a window will pop up. You will need to sign in to your Google Account, follow the steps and click **Continue**. This will give Colab access to the data on the drive.\n",
"\n",
" Once this is done, your data are available in the **Files** tab on the top left of notebook."
]
@@ -1247,13 +1236,9 @@
"source": [
"#@markdown ##Play the cell to connect your Google Drive to Colab\n",
"\n",
- "#@markdown * Click on the URL.\n",
- "\n",
- "#@markdown * Sign in your Google Account.\n",
+ "#@markdown * Click on **Connect to Google Drive**.\n",
"\n",
- "#@markdown * Copy the authorization code.\n",
- "\n",
- "#@markdown * Enter the authorization code.\n",
+    "#@markdown * A new window will pop up. Sign in to your Google Account.\n",
"\n",
"#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\".\n",
"\n",
@@ -1592,7 +1577,7 @@
"\n",
" **You do not need to run this section if you want to train a network from scratch**.\n",
"\n",
- " This option allows you to use pre-trained models from the [BioImage Model Zoo](https://bioimage.io/#/) and fine-tune them to analyse new data. Choose `bioimageio_model` and provide the ID in `bioimageio_model_id` (e.g., \"creative-panda\" or \"10.5281/zenodo.5817052\").\n",
+ " This option allows you to use pre-trained models from the [BioImage Model Zoo](https://bioimage.io/#/) and fine-tune them to analyse new data. Choose `bioimageio_model` and provide the ID in `bioimageio_model_id` (e.g., \"placid-llama\" or \"10.5281/zenodo.5817052\").\n",
"\n",
" This option also allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. Choose `Model_from_file` and provide the `pretrained_model_path`.\n",
"\n",
@@ -1609,6 +1594,8 @@
"outputs": [],
"source": [
"# @markdown ##Loading weights from a pre-trained network\n",
+ "from bioimageio.core import load_description\n",
+ "from bioimageio.spec.utils import download\n",
"\n",
"Use_pretrained_model = False #@param {type:\"boolean\"}\n",
"pretrained_model_choice = \"BioImage Model Zoo\" #@param [\"Model_from_file\", \"BioImage Model Zoo\"]\n",
@@ -1629,33 +1616,32 @@
" h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n",
" qc_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')\n",
" elif pretrained_model_choice == \"BioImage Model Zoo\":\n",
- "\n",
- " model_spec = load_resource_description(bioimageio_model_id)\n",
- " if \"keras_hdf5\" not in biomodel.weights:\n",
+ " model_spec = load_description(bioimageio_model_id)\n",
+ " if \"keras_hdf5\" not in model_spec.weights.model_fields_set:\n",
" print(\"Invalid bioimageio model\")\n",
" h5_file_path = \"no-model\"\n",
" qc_path = \"no-qc\"\n",
" else:\n",
- " h5_file_path = str(biomodel.weights[\"keras_hdf5\"].source)\n",
+ " h5_file_path = str(download(model_spec.weights.keras_hdf5.source).path)\n",
" try:\n",
- " attachments = biomodel.attachments.files\n",
- " qc_path = [fname for fname in attachments if fname.endswith(\"training_evaluation.csv\")][0]\n",
- " qc_path = os.path.join(base_path + \"//bioimageio_pretrained_model\", qc_path)\n",
+ " attachments = model_spec.attachments.files\n",
+ " qc_path = str(download([fname for fname in attachments if str(fname).endswith(\"training_evaluation.csv\")][0]).path)\n",
+ " # qc_path = os.path.join(base_path + \"//bioimageio_pretrained_model\", qc_path)\n",
" except Exception:\n",
" qc_path = \"no-qc\"\n",
"\n",
"# --------------------- Check the model exist ------------------------\n",
- "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled,\n",
+ "\n",
" if not os.path.exists(h5_file_path):\n",
+    "    # If the chosen model path does not contain a pretrained model, Use_pretrained_model is disabled\n",
" print(R+'WARNING: pretrained model does not exist')\n",
" Use_pretrained_model = False\n",
+ " else:\n",
+    "    # If the model path contains a pretrained model, we load its learning rate\n",
"\n",
- "\n",
- "# If the model path contains a pretrain model, we load the training rate,\n",
- " if os.path.exists(h5_file_path):\n",
- "#Here we check if the learning rate can be loaded from the quality control folder\n",
- " # if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n",
" if os.path.exists(qc_path):\n",
+ " #Here we check if the learning rate can be loaded from the quality control folder\n",
+ " # if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n",
"\n",
" # with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n",
" with open(qc_path,'r') as csvfile:\n",
@@ -1663,7 +1649,7 @@
" #print(csvRead)\n",
"\n",
" if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n",
- " print(\"pretrained network learning rate found\")\n",
+    "            print(\"A 'learning rate' column was found in the provided pre-trained model.\")\n",
" #find the last learning rate\n",
" lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n",
" #Find the learning rate corresponding to the lowest validation loss\n",
@@ -1672,28 +1658,29 @@
" bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n",
"\n",
" if Weights_choice == \"last\":\n",
- " print('Last learning rate: '+str(lastLearningRate))\n",
- "\n",
- " if Weights_choice == \"best\":\n",
- " print('Learning rate of best validation loss: '+str(bestLearningRate))\n",
+    "              print(f'You will be loading the \\033[1mlast\\033[0m learning rate: {lastLearningRate}')\n",
+ " elif Weights_choice == \"best\":\n",
+ " print(f'You will be loading the learning rate of \\033[1mbest\\033[0m validation loss: {bestLearningRate}')\n",
+ " else:\n",
+ " #if the column does not exist, then initial learning rate is used instead\n",
+ " print(f\"{bcolors.WARNING}WARNING: The learning rate cannot be identified from the pretrained network{W}\")\n",
+ " print(f\"{bcolors.WARNING}Default learning rate of {initial_learning_rate} will be used instead{W}\")\n",
"\n",
- " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n",
" bestLearningRate = initial_learning_rate\n",
" lastLearningRate = initial_learning_rate\n",
- " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n",
- "\n",
- "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n",
- " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n",
- " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n",
+ " else:\n",
+ " #Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n",
+    "      print(f\"{bcolors.WARNING}Sorry, 'training_evaluation.csv' does not exist or could not be loaded.{W}\")\n",
+ " print(f\"{bcolors.WARNING}Default learning rate of {initial_learning_rate} will be used instead{W}\")\n",
" bestLearningRate = initial_learning_rate\n",
" lastLearningRate = initial_learning_rate\n",
"\n",
"\n",
"# Display info about the pretrained model to be loaded (or not)\n",
"if Use_pretrained_model:\n",
- " print('Weights found in:')\n",
- " print(h5_file_path)\n",
- " print('will be loaded prior to training.')\n",
+ " print('-'*50)\n",
+ " print(f'Weights found in: {h5_file_path}')\n",
+ " print('They will be loaded prior to training.')\n",
"\n",
"else:\n",
" print(R+'No pretrained network will be used.')\n",
@@ -1780,15 +1767,17 @@
"# Load the pretrained weights\n",
"if Use_pretrained_model:\n",
" try:\n",
    "    model.load_weights(h5_file_path)\n",
+    "    print(\"Weights correctly loaded.\")\n",
" except:\n",
- " print(bcolors.WARNING + \"The pretrained model could not be loaded as the configuration of the network is different.\")\n",
- " print(\"Please, read the model specifications and check the parameters in Section 3.1\" + W)\n",
+ " print(f\"{bcolors.WARNING}The pretrained model could not be loaded as the configuration of the network is different.\")\n",
+    "    print(\"Please read the model specifications and check the parameters in Section 3.1.\")\n",
+    "    print(f\"The mismatch most likely comes from the pooling_steps parameter, so please double-check it.{W}\")\n",
"\n",
- "# except:\n",
- "# print(\"The pretrained model could not be loaded. Please, check the parameters of the pre-trained model architecture.\")\n",
"config_model= model.optimizer.get_config()\n",
- "print(config_model)\n",
+ "print(\"Configuration of model's optimizer:\")\n",
+ "for k,v in config_model.items():\n",
+ " print(f\"{k} : {v}\")\n",
"\n",
"\n",
"# ------------------ Failsafes ------------------\n",
@@ -2252,7 +2241,7 @@
"if data_from_bioimage_model_zoo:\n",
" training_data = {'id' : training_data_ID}\n",
"else:\n",
- " training_data = None \n",
+ " training_data = None\n",
"\n",
"# create the author/maintainer/packager spec input\n",
"author_1_spec = make_author(Trained_model_author_1)\n",
@@ -2311,7 +2300,7 @@
"\n",
"# load the input image, crop it if necessary, and save as numpy file\n",
"# The crop will be centered to get an image with some content.\n",
- "input_img = io.imread(fileID, as_gray = True)\n",
+ "input_img = io.imread(fileID, as_gray = True).astype(np.float32)\n",
"assert input_img.ndim == 2,'Example input image is not a 2D grayscale image. Please, provide a 2D grayscale image.'\n",
"\n",
"# batch should never be constrained\n",
@@ -2331,7 +2320,6 @@
"test_img = test_img[x_size : x_size + shape[1],\n",
" y_size : y_size + shape[2]]\n",
"\n",
- "\n",
"# Save the test image\n",
"test_input_path = os.path.join(output_root, \"test_input.npy\")\n",
"np.save(test_input_path, test_img[None, ..., None])\n",
@@ -2358,22 +2346,24 @@
" channel_names.append(f'channel{idx}')\n",
"\n",
"# create the input tensor\n",
- "input_tensor = bioimageio_spec.InputTensorDescr(id=bioimageio_spec.TensorId('input0'), \n",
+ "input_tensor = bioimageio_spec.InputTensorDescr(id=bioimageio_spec.TensorId('input0'),\n",
" description= 'This is the test input tensor created from the example image.',\n",
- " axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None), \n",
- " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='y', description='', type='space', unit=None, scale=1.0, concatenable=False), \n",
- " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='x', description='', type='space', unit=None, scale=1.0, concatenable=False), \n",
+ " axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),\n",
+ " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='y', description='', type='space', unit=None, scale=1.0, concatenable=False),\n",
+ " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='x', description='', type='space', unit=None, scale=1.0, concatenable=False),\n",
" bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=['channel0'])],\n",
- " test_tensor = bioimageio_spec.FileDescr(source = test_input_path), \n",
- " preprocessing = [bioimageio_spec.ScaleRangeDescr(kwargs=bioimageio_spec.ScaleRangeKwargs(axes = ['x','y'], min_percentile = 1.0 , max_percentile = 99.8) )],\n",
+ " test_tensor = bioimageio_spec.FileDescr(source = test_input_path),\n",
+ " preprocessing = [bioimageio_spec.EnsureDtypeDescr(kwargs=bioimageio_spec.EnsureDtypeKwargs(dtype=\"float32\")),\n",
+ " bioimageio_spec.ScaleRangeDescr(kwargs=bioimageio_spec.ScaleRangeKwargs(axes = ['x','y'], min_percentile = 1.0 , max_percentile = 99.8) ),\n",
+ " ],\n",
" )\n",
- " \n",
+ "\n",
"\n",
"\n",
"# create the output tensor\n",
- "output_tensor = bioimageio_spec.OutputTensorDescr( axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None), \n",
- " bioimageio_spec.SpaceOutputAxis(size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='y', offset=0), id='y', description='', type='space', unit=None, scale=1.0), \n",
- " bioimageio_spec.SpaceOutputAxis( size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='x', offset=0), id='x', description='', type='space', unit=None, scale=1.0), \n",
+ "output_tensor = bioimageio_spec.OutputTensorDescr( axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),\n",
+ " bioimageio_spec.SpaceOutputAxis(size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='y', offset=0), id='y', description='', type='space', unit=None, scale=1.0),\n",
+ " bioimageio_spec.SpaceOutputAxis( size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='x', offset=0), id='x', description='', type='space', unit=None, scale=1.0),\n",
" bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=channel_names)],\n",
" test_tensor = bioimageio_spec.FileDescr(source = test_output_path) )\n",
"\n",
@@ -2382,7 +2372,7 @@
"qc_path = os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')\n",
"if os.path.exists(qc_path):\n",
" attachments.append(FileDescr(source = qc_path))\n",
- " \n",
+ "\n",
"# Include a post-processing deepImageJ macro\n",
"macro = \"Contours2InstanceSegmentation.ijm\"\n",
"url = f\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/{macro}\"\n",
@@ -2431,7 +2421,7 @@
" packaged_by=[packager_spec],\n",
" weights=unet_weights,\n",
" training_data=training_data,\n",
- " \n",
+ "\n",
" )\n",
"\n",
"\n",
@@ -2439,10 +2429,10 @@
"summary = bioimageio_core.test_model(model_description, weight_format=\"keras_hdf5\")\n",
"summary.display()\n",
"\n",
- "success = summary.status == \"passed\" \n",
+ "success = summary.status == \"passed\"\n",
"\n",
"save_bioimageio_package(model_description, output_path=Path(output_path))\n",
- " \n",
+ "\n",
"if success:\n",
" print(\"The bioimage.io model was successfully exported to\", output_path)\n",
"else:\n",
@@ -2600,7 +2590,8 @@
"---\n",
"**v2.1.3**: \n",
"\n",
- "* Updated Bioimage.io model export to latest version (core-0.6.9, spec-0.5.3.2)\n",
+ "* Updated Bioimage.IO model export to latest version (core-0.6.9, spec-0.5.3.2)\n",
+    "* Fixed model import from Bioimage.IO\n",
"* Fixed Tensorflow version to 2.15\n",
"* Bug fixes\n",
"\n",
@@ -2631,11 +2622,11 @@
"metadata": {
"accelerator": "GPU",
"colab": {
+ "gpuType": "T4",
"provenance": []
},
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
+ "display_name": "Python 3",
"name": "python3"
},
"language_info": {
diff --git a/Colab_notebooks/U_Net_2D_Multilabel_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U_Net_2D_Multilabel_ZeroCostDL4Mic.ipynb
deleted file mode 100644
index 33f7e44..0000000
--- a/Colab_notebooks/U_Net_2D_Multilabel_ZeroCostDL4Mic.ipynb
+++ /dev/null
@@ -1,2647 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "IkSguVy8Xv83"
- },
- "source": [
- "# **U-Net (2D) for multilabel segmentation (semantic segmentation)**\n",
- "---\n",
- "\n",
- "U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n",
- "\n",
-    "The main difference between this U-Net and the original one is that the output is a semantic mask rather than a binary mask: it allows the segmentation of different kinds of structures, objects or tissues present in the image by labelling pixels with different values (0, 1, 2, ...) rather than just 0 and 1 (binary segmentation). So, to use this notebook you need to provide a mask image where each type of label has a different pixel value. For example, 0 = background, 1 = cytoplasm, 2 = lumen and 3 = nuclei.\n",
- "\n",
-    " **This particular notebook enables image segmentation of 2D datasets. If you are interested in 3D datasets, you should use the 3D U-Net notebook instead.**\n",
- "\n",
- "---\n",
- "*Disclaimer*:\n",
- "\n",
- "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n",
- "\n",
- "This notebook is largely based on the papers:\n",
- "\n",
- "**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n",
- "\n",
- "and\n",
- "\n",
- "**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n",
- "(https://www.nature.com/articles/s41592-018-0261-2)\n",
- "And source code found in: https://github.com/zhixuhao/unet by *Zhixuhao*\n",
- "\n",
- "**Please also cite this original paper when using or developing this notebook.**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jWAz2i7RdxUV"
- },
- "source": [
- "# **How to use this notebook?**\n",
- "\n",
- "---\n",
- "\n",
-    "Videos describing how to use our notebooks are available on YouTube:\n",
- " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n",
- " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n",
- "\n",
- "\n",
- "\n",
- "---\n",
- "### **Structure of a notebook**\n",
- "\n",
-    "The notebook contains two types of cells:\n",
-    "\n",
-    "**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n",
-    "\n",
-    "**Code cells** contain code, and the code can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. Once execution is done, the play button animation stops. You can create a new code cell by clicking `+ Code`.\n",
- "\n",
- "---\n",
- "### **Table of contents, Code snippets** and **Files**\n",
- "\n",
- "On the top left side of the notebook you find three tabs which contain from top to bottom:\n",
- "\n",
- "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n",
- "\n",
- "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n",
- "\n",
- "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here.\n",
- "\n",
- "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n",
- "\n",
- "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n",
- "\n",
- "---\n",
- "### **Making changes to the notebook**\n",
- "\n",
- "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n",
- "\n",
- "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n",
- "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gKDLkLWUd-YX"
- },
- "source": [
- "# **0. Before getting started**\n",
- "---\n",
- "\n",
- "Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n",
- "\n",
- "For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n",
- "\n",
- "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n",
- "\n",
- "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n",
- "\n",
- "Here's a common data structure that can work:\n",
- "* Experiment A\n",
- " - **Training dataset**\n",
- " - Training_source\n",
- " - img_1.tif, img_2.tif, ...\n",
- " - Training_target\n",
- " - img_1.tif, img_2.tif, ...\n",
- " - **Quality control dataset**\n",
- " - Training_source\n",
- " - img_1.tif, img_2.tif\n",
- " - Training_target\n",
- " - img_1.tif, img_2.tif\n",
- " - **Data to be predicted**\n",
- " - **Results**\n",
- "\n",
- "---\n",
- "**Important note**\n",
- "\n",
- "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n",
- "\n",
- "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n",
- "\n",
- "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n",
- "---"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AdN8B91xZO0x"
- },
- "source": [
- "# **1. Install U-Net dependencies**\n",
- "---\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "UGWnGOFsf07b"
- },
- "source": [
- "## **1.1. Install key dependencies**\n",
- "---\n",
- ""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "uc0haIa-fZiG"
- },
- "outputs": [],
- "source": [
- "#@markdown ##Play to install U-Net dependencies\n",
- "# Install packages which are not included in Google Colab\n",
- "!pip install -q data\n",
- "!pip install -q tifffile # contains tools to operate tiff-files\n",
- "!pip install -q wget\n",
- "!pip install -q fpdf2\n",
- "!pip install -q PTable # Nice tables\n",
- "!pip install -q zarr\n",
- "!pip install -q imagecodecs\n",
- "!pip install -q bioimageio.core==0.6.9\n",
- "!pip install -q tf-keras==2.15\n",
- "!pip install -q tensorflow==2.15"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "I4O5zctbf4Gb"
- },
- "source": [
- "\n",
- "## **1.2. Restart your runtime**\n",
- "---\n",
- "\n",
- "\n",
- "\n",
-    "**Ignore the following error message. Your runtime has automatically restarted. This is normal.**\n",
- "\n",
- " "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "iiX3Ly-7gA5h"
- },
- "source": [
- "## **1.3. Load key dependencies**\n",
- "---\n",
- ""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "fq21zJVFNASx"
- },
- "outputs": [],
- "source": [
- "from __future__ import print_function\n",
- "Notebook_version = '2.1.3'\n",
- "Network = 'U-Net (2D) multilabel'\n",
- "\n",
- "from builtins import any as b_any\n",
- "\n",
- "def get_requirements_path():\n",
- " # Store requirements file in 'contents' directory\n",
- " current_dir = os.getcwd()\n",
- " dir_count = current_dir.count('/') - 1\n",
- " path = '../' * (dir_count) + 'requirements.txt'\n",
- " return path\n",
- "\n",
- "def filter_files(file_list, filter_list):\n",
- " filtered_list = []\n",
- " for fname in file_list:\n",
- " if b_any(fname.split('==')[0] in s for s in filter_list):\n",
- " filtered_list.append(fname)\n",
- " return filtered_list\n",
- "\n",
- "def build_requirements_file(before, after):\n",
- " path = get_requirements_path()\n",
- "\n",
- " # Exporting requirements.txt for local run\n",
- " !pip freeze > $path\n",
- "\n",
- " # Get minimum requirements file\n",
- " df = pd.read_csv(path)\n",
- " mod_list = [m.split('.')[0] for m in after if not m in before]\n",
- " req_list_temp = df.values.tolist()\n",
- " req_list = [x[0] for x in req_list_temp]\n",
- "\n",
- " # Replace with package name and handle cases where import name is different to module name\n",
- " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n",
- " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n",
- " filtered_list = filter_files(req_list, mod_replace_list)\n",
- "\n",
- " file=open(path,'w')\n",
- " for item in filtered_list:\n",
- " file.writelines(item)\n",
- "\n",
- " file.close()\n",
- "\n",
- "import sys\n",
- "before = [str(m) for m in sys.modules]\n",
- "\n",
- "#@markdown ##Load key U-Net dependencies\n",
- "\n",
-    "#As this notebook depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in Colab)\n",
- "#only the data library needs to be additionally installed.\n",
- "import tensorflow as tf\n",
- "# print(tensorflow.__version__)\n",
- "# print(\"Tensorflow enabled.\")\n",
- "\n",
- "\n",
- "# Keras imports\n",
- "from tensorflow.keras import models\n",
- "from tensorflow.keras.models import Model, load_model\n",
- "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n",
- "from tensorflow.keras.optimizers import Adam\n",
- "# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n",
- "from tensorflow.keras.callbacks import ModelCheckpoint\n",
- "from tensorflow.keras.callbacks import ReduceLROnPlateau\n",
- "from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n",
- "from tensorflow.keras import backend as keras\n",
- "from tensorflow.keras.callbacks import Callback\n",
- "\n",
- "# General import\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import os\n",
- "import glob\n",
- "from skimage import img_as_ubyte, io, transform\n",
- "import matplotlib as mpl\n",
- "from matplotlib import pyplot as plt\n",
- "from matplotlib.pyplot import imread\n",
- "from pathlib import Path\n",
- "import shutil\n",
- "import random\n",
- "import time\n",
- "import csv\n",
- "import sys\n",
- "from math import ceil\n",
- "from fpdf import FPDF, HTMLMixin\n",
- "from pip._internal.operations.freeze import freeze\n",
- "import subprocess\n",
- "# Imports for QC\n",
- "from PIL import Image\n",
- "from scipy import signal\n",
- "from scipy import ndimage\n",
- "from sklearn.linear_model import LinearRegression\n",
- "from skimage.util import img_as_uint\n",
- "from skimage.metrics import structural_similarity\n",
- "from skimage.metrics import peak_signal_noise_ratio as psnr\n",
- "\n",
- "# For sliders and dropdown menu and progress bar\n",
- "from ipywidgets import interact\n",
- "import ipywidgets as widgets\n",
- "# from tqdm import tqdm\n",
- "from tqdm.notebook import tqdm\n",
- "\n",
- "from sklearn.feature_extraction import image\n",
- "from skimage import img_as_ubyte, io, transform\n",
- "from skimage.util.shape import view_as_windows\n",
- "\n",
- "from datetime import datetime\n",
- "\n",
- "\n",
- "# Suppressing some warnings\n",
- "import warnings\n",
- "warnings.filterwarnings('ignore')\n",
- "\n",
- "# BioImage Model Zoo\n",
- "from shutil import rmtree\n",
- "import bioimageio.spec.model.v0_5 as bioimageio_spec\n",
- "from bioimageio.spec import save_bioimageio_package\n",
- "from bioimageio.spec._internal.io import FileDescr\n",
- "import bioimageio.core as bioimageio_core\n",
- "from zipfile import ZipFile\n",
- "import requests\n",
- "from bioimageio.spec.pretty_validation_errors import (\n",
- " enable_pretty_validation_errors_in_ipynb,\n",
- ")\n",
- "enable_pretty_validation_errors_in_ipynb()\n",
- "\n",
- "#Create a variable to get and store relative base path\n",
- "base_path = os.getcwd()\n",
- "\n",
- "def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n",
- " \"\"\"\n",
- " Function creates patches from the Training_source and Training_target images.\n",
- " The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n",
- " Saves all created patches in two new directories in the base_path folder.\n",
- "\n",
- " Returns: - Two paths to where the patches are now saved\n",
- " \"\"\"\n",
- " DEBUG = False\n",
- "\n",
- " Patch_source = os.path.join(base_path,'img_patches')\n",
- " Patch_target = os.path.join(base_path,'mask_patches')\n",
- " Patch_rejected = os.path.join(base_path,'rejected')\n",
- "\n",
- " #Here we save the patches, in the /content directory as they will not usually be needed after training\n",
- " if os.path.exists(Patch_source):\n",
- " shutil.rmtree(Patch_source)\n",
- " if os.path.exists(Patch_target):\n",
- " shutil.rmtree(Patch_target)\n",
- " if os.path.exists(Patch_rejected):\n",
- " shutil.rmtree(Patch_rejected)\n",
- "\n",
- " os.mkdir(Patch_source)\n",
- " os.mkdir(Patch_target)\n",
- " os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n",
- "\n",
- " patch_num = 0\n",
- " Training_source_list = [f for f in os.listdir(Training_source) if not f.startswith(\".\")]\n",
- " for file in tqdm(Training_source_list):\n",
- "\n",
- " img = io.imread(os.path.join(Training_source, file))\n",
- " mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n",
- "\n",
- " if DEBUG:\n",
- " print(file)\n",
- " print(img.dtype)\n",
- "\n",
- " # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n",
- " patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n",
- " patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n",
- "\n",
- " patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n",
- " patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n",
- "\n",
- " if DEBUG:\n",
- " print(all_patches_img.shape)\n",
- " print(all_patches_img.dtype)\n",
- "\n",
- " for i in range(patches_img.shape[0]):\n",
- " img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n",
- " mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n",
- " patch_num += 1\n",
- "\n",
-    "      # if the mask contains at least 2% of its total number of pixels as mask, then go ahead and save the images\n",
- " pixel_threshold_array = sorted(patches_mask[i].flatten())\n",
- " if pixel_threshold_array[int(round((len(pixel_threshold_array)-1)*(1-min_fraction)))]>0:\n",
- " io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n",
- " io.imsave(mask_save_path, patches_mask[i])\n",
- " else:\n",
- " io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n",
- " io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', patches_mask[i])\n",
- "\n",
- " return Patch_source, Patch_target\n",
- "\n",
- "\n",
- "def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n",
- "\n",
- " files = [f for f in os.listdir(data_path) if not f.startswith(\".\")]\n",
- "\n",
- " # Get the size of the first image found in the folder and initialise the variables to that\n",
- " n = 0\n",
- " while os.path.isdir(os.path.join(data_path, files[n])):\n",
- " n += 1\n",
- " (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n",
- "\n",
- " # Screen the size of all dataset to find the minimum image size\n",
- " for file in files:\n",
- " if not os.path.isdir(os.path.join(data_path, file)):\n",
- " (height, width) = Image.open(os.path.join(data_path, file)).size\n",
- " if width < width_min:\n",
- " width_min = width\n",
- " if height < height_min:\n",
- " height_min = height\n",
- "\n",
- " # Find the power of patches that will fit within the smallest dataset\n",
- " width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n",
- "\n",
- " # Clip values at maximum permissible values\n",
- " if width_min > max_width:\n",
- " width_min = max_width\n",
- "\n",
- " if height_min > max_height:\n",
- " height_min = max_height\n",
- "\n",
- " return (width_min, height_min)\n",
- "\n",
- "def fittingPowerOfTwo(number):\n",
- " n = 0\n",
- " while 2**n <= number:\n",
- " n += 1\n",
- " return 2**(n-1)\n",
- "\n",
- "## TODO: create weighted CE for semantic labels\n",
- "def getClassWeights(Training_target_path):\n",
- "\n",
- " Mask_dir_list = [f for f in os.listdir(Training_target_path) if not f.startswith(\".\")]\n",
- "\n",
- " number_of_dataset = len(Mask_dir_list)\n",
- "\n",
- " class_count = np.zeros(2, dtype=int)\n",
- " for i in tqdm(range(number_of_dataset)):\n",
- " mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n",
- " mask = normalizeMinMax(mask)\n",
- " class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n",
- " class_count[1] += mask.sum()\n",
- "\n",
- " n_samples = class_count.sum()\n",
- " n_classes = 2\n",
- "\n",
- " class_weights = n_samples / (n_classes * class_count)\n",
- " return class_weights\n",
- "\n",
- "def weighted_binary_crossentropy(class_weights):\n",
- "\n",
- " def _weighted_binary_crossentropy(y_true, y_pred):\n",
- " binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n",
- " weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n",
- " weighted_binary_crossentropy = weight_vector * binary_crossentropy\n",
- "\n",
- " return keras.mean(weighted_binary_crossentropy)\n",
- "\n",
- " return _weighted_binary_crossentropy\n",
- "\n",
- "\n",
- "def save_augment(datagen,orig_img,dir_augmented_data = base_path + \"/augment\"):\n",
- " \"\"\"\n",
- " Saves a subset of the augmented data for visualisation, by default in /content.\n",
- "\n",
- " This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n",
- "\n",
- " \"\"\"\n",
- " try:\n",
- " os.mkdir(dir_augmented_data)\n",
- " except:\n",
- " ## if the preview folder exists, then remove\n",
- " ## the contents (pictures) in the folder\n",
- " dir_augmented_data_list = [f for f in os.listdir(dir_augmented_data) if not f.startswith(\".\")]\n",
- " for item in dir_augmented_data_list:\n",
- " os.remove(dir_augmented_data + \"/\" + item)\n",
- "\n",
- " ## convert the original image to array\n",
- " x = img_to_array(orig_img)\n",
-    "    ## reshape (Sample, Nrow, Ncol, 3) 3 = R, G or B\n",
- " #print(x.shape)\n",
- " x = x.reshape((1,) + x.shape)\n",
- " #print(x.shape)\n",
- " ## -------------------------- ##\n",
- " ## randomly generate pictures\n",
- " ## -------------------------- ##\n",
- " i = 0\n",
- " #We will just save 5 images,\n",
- " #but this can be changed, but note the visualisation in 3. currently uses 5.\n",
- " Nplot = 5\n",
- " for batch in datagen.flow(x,batch_size=1,\n",
- " save_to_dir=dir_augmented_data,\n",
- " save_format='tif',\n",
- " seed=42):\n",
- " i += 1\n",
- " if i > Nplot - 1:\n",
- " break\n",
- "\n",
- "# Generators\n",
- "def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size, validatio_split):\n",
- " '''\n",
- " Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n",
- "\n",
- " datagen: ImageDataGenerator\n",
- " subset: can take either 'training' or 'validation'\n",
- " '''\n",
- "\n",
- " # Build the dict for the ImageDataGenerator\n",
- " # non_aug_args = dict(width_shift_range = 0,\n",
- " # height_shift_range = 0,\n",
- " # rotation_range = 0, #90\n",
- " # zoom_range = 0,\n",
- " # shear_range = 0,\n",
- " # horizontal_flip = False,\n",
- " # vertical_flip = False,\n",
- " # fill_mode = 'reflect')\n",
- " # default params of data generator is without augmentation\n",
- " mask_load_gen = ImageDataGenerator(dtype='uint8', validation_split=validatio_split)\n",
- " image_load_gen = ImageDataGenerator(dtype='float32', validation_split=validatio_split, preprocessing_function = normalizePercentile)\n",
- "\n",
- " image_generator = image_load_gen.flow_from_directory(\n",
- " os.path.dirname(image_folder_path),\n",
- " classes = [os.path.basename(image_folder_path)],\n",
- " class_mode = None,\n",
- " color_mode = \"grayscale\",\n",
- " target_size = target_size,\n",
- " batch_size = batch_size,\n",
- " subset = subset,\n",
- " interpolation = \"bicubic\",\n",
- " seed = 1)\n",
- " mask_generator = mask_load_gen.flow_from_directory(\n",
- " os.path.dirname(mask_folder_path),\n",
- " classes = [os.path.basename(mask_folder_path)],\n",
- " class_mode = None,\n",
- " color_mode = \"grayscale\",\n",
- " target_size = target_size,\n",
- " batch_size = batch_size,\n",
- " subset = subset,\n",
- " interpolation = \"nearest\",\n",
- " seed = 1)\n",
- "\n",
- " this_generator = zip(image_generator, mask_generator)\n",
- " for (img,mask) in this_generator:\n",
- " if subset == 'training':\n",
- " # Apply the data augmentation\n",
- " # the same seed should provide always the same transformation and image loading\n",
- " seed = np.random.randint(100000)\n",
- " for batch_im in image_datagen.flow(img,batch_size=batch_size, seed=seed):\n",
- " break\n",
- " mask = mask.astype(np.float32)\n",
- " labels = np.unique(mask)\n",
- " if len(labels)>1:\n",
- " batch_mask = np.zeros_like(mask, dtype='float32')\n",
- " for l in range(0, len(labels)):\n",
- " aux = (mask==l).astype(np.float32)\n",
- " for batch_aux in mask_datagen.flow(aux,batch_size=batch_size, seed=seed):\n",
- " break\n",
- " batch_mask += l*(batch_aux>0).astype(np.float32)\n",
- " index = np.where(batch_mask>l)\n",
- " batch_mask[index]=l\n",
- " else:\n",
- " batch_mask = mask\n",
- "\n",
- " yield (batch_im,batch_mask)\n",
- "\n",
- " else:\n",
- " yield (img,mask)\n",
- "\n",
- "\n",
- "def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512), validatio_split = 0.1):\n",
- " image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n",
- " mask_datagen = ImageDataGenerator(**datagen_parameters)\n",
- "\n",
- " train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size, validatio_split)\n",
- " validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size, validatio_split)\n",
- "\n",
- " return (train_datagen, validation_datagen)\n",
- "\n",
- "\n",
- "# Normalization functions from Martin Weigert\n",
- "def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n",
- " \"\"\"This function is adapted from Martin Weigert\"\"\"\n",
- " \"\"\"Percentile-based image normalization.\"\"\"\n",
- "\n",
- " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n",
- " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n",
- " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n",
- "\n",
- "\n",
- "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n",
- " \"\"\"This function is adapted from Martin Weigert\"\"\"\n",
- " if dtype is not None:\n",
- " x = x.astype(dtype,copy=False)\n",
- " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n",
- " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n",
- " eps = dtype(eps)\n",
- "\n",
- " try:\n",
- " import numexpr\n",
- " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n",
- " except ImportError:\n",
- " x = (x - mi) / ( ma - mi + eps )\n",
- "\n",
- " if clip:\n",
- " x = np.clip(x,0,1)\n",
- "\n",
- " return x\n",
- "\n",
- "\n",
- "\n",
-    "# Simple normalization to min/max for the Mask\n",
- "def normalizeMinMax(x, dtype=np.float32):\n",
- " x = x.astype(dtype,copy=False)\n",
- " x = (x - np.amin(x)) / (np.amax(x) - np.amin(x) + 1e-10)\n",
- " return x\n",
- "\n",
- "\n",
- "# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network.\n",
- "def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, labels=2):\n",
- " inputs = Input(input_size)\n",
- " conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n",
- " conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n",
- " # Downsampling steps\n",
- " pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n",
- " conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n",
- " conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n",
- "\n",
- " if pooling_steps > 1:\n",
- " pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n",
- " conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n",
- " conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n",
- "\n",
- " if pooling_steps > 2:\n",
- " pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n",
- " conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n",
- " conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n",
- " drop4 = Dropout(0.5)(conv4)\n",
- "\n",
- " if pooling_steps > 3:\n",
- " pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n",
- " conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n",
- " conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n",
- " drop5 = Dropout(0.5)(conv5)\n",
- "\n",
- " #Upsampling steps\n",
- " up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n",
- " merge6 = concatenate([drop4,up6], axis = 3)\n",
- " conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n",
- " conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n",
- "\n",
- " if pooling_steps > 2:\n",
- " up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n",
- " if pooling_steps > 3:\n",
- " up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n",
- " merge7 = concatenate([conv3,up7], axis = 3)\n",
- " conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n",
- " conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)\n",
- "\n",
- " if pooling_steps > 1:\n",
- " up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n",
- " if pooling_steps > 2:\n",
- " up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n",
- " merge8 = concatenate([conv2,up8], axis = 3)\n",
- " conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n",
- " conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)\n",
- "\n",
- " if pooling_steps == 1:\n",
- " up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n",
- " else:\n",
- " up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n",
- "\n",
- " merge9 = concatenate([conv1,up9], axis = 3)\n",
- " conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n",
- " conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n",
- " conv9 = Conv2D(labels, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n",
- " conv10 = Conv2D(labels, 1, activation = 'softmax')(conv9)\n",
- "\n",
- " model = Model(inputs = inputs, outputs = conv10)\n",
- "\n",
- " model.compile(optimizer = Adam(lr = learning_rate), loss = 'sparse_categorical_crossentropy')\n",
- "\n",
- " if verbose:\n",
- " model.summary()\n",
- "\n",
- " if(pretrained_weights):\n",
- " model.load_weights(pretrained_weights)\n",
- "\n",
- " return model\n",
- "\n",
- "# Custom callback showing sample prediction\n",
- "class SampleImageCallback(Callback):\n",
- "\n",
- " def __init__(self, model, sample_data, model_path, save=False):\n",
- " self.model = model\n",
- " self.sample_data = sample_data\n",
- " self.model_path = model_path\n",
- " self.save = save\n",
- "\n",
- " def on_epoch_end(self, epoch, logs={}):\n",
- " if np.mod(epoch,5) == 0:\n",
- " sample_predict = self.model.predict_on_batch(self.sample_data)\n",
- "\n",
- " f=plt.figure(figsize=(16,8))\n",
- " plt.subplot(1,labels+1,1)\n",
- " plt.imshow(self.sample_data[0,:,:,0], cmap='gray')\n",
- " plt.title('Sample source')\n",
- " plt.axis('off');\n",
- " for i in range(1, labels):\n",
- " plt.subplot(1,labels+1,i+1)\n",
- " plt.imshow(sample_predict[0,:,:,i], interpolation='nearest', cmap='magma')\n",
- " plt.title('Predicted label {}'.format(i))\n",
- " plt.axis('off');\n",
- "\n",
- " plt.subplot(1,labels+1,labels+1)\n",
- " plt.imshow(np.squeeze(np.argmax(sample_predict[0], axis=-1)), interpolation='nearest')\n",
- " plt.title('Semantic segmentation')\n",
- " plt.axis('off');\n",
- "\n",
- " plt.show()\n",
- "\n",
- " if self.save:\n",
- " plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n",
- " Patch_source_list = [f for f in os.listdir(Patch_source) if not f.startswith(\".\")]\n",
- " random_choice = random.choice(Patch_source_list)\n",
- "\n",
- "def predict_as_tiles(Image_path, model):\n",
- "\n",
- " # Read the data in and normalize\n",
- " Image_raw = io.imread(Image_path, as_gray = True)\n",
- " Image_raw = normalizePercentile(Image_raw)\n",
- "\n",
- " # Get the patch size from the input layer of the model\n",
- " patch_size = model.layers[0].output_shape[0][1:3]\n",
- "\n",
- " # Pad the image with zeros if any of its dimensions is smaller than the patch size\n",
- " if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n",
- " Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n",
- " Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n",
- " else:\n",
- " Image = Image_raw\n",
- "\n",
- " # Calculate the number of patches in each dimension\n",
- " n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n",
- " n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n",
- "\n",
- " prediction = np.zeros(Image.shape, dtype = 'uint8')\n",
- "\n",
- " for x in range(n_patch_in_width):\n",
- " for y in range(n_patch_in_height):\n",
- " xi = patch_size[0]*x\n",
- " yi = patch_size[1]*y\n",
- "\n",
- " # If the patch exceeds the edge of the image shift it back\n",
- " if xi+patch_size[0] >= Image.shape[0]:\n",
- " xi = Image.shape[0]-patch_size[0]\n",
- "\n",
- " if yi+patch_size[1] >= Image.shape[1]:\n",
- " yi = Image.shape[1]-patch_size[1]\n",
- "\n",
- " # Extract and reshape the patch\n",
- " patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n",
- " patch = np.reshape(patch,patch.shape+(1,))\n",
- " patch = np.reshape(patch,(1,)+patch.shape)\n",
- "\n",
- " # Get the prediction from the patch and paste it in the prediction in the right place\n",
- " predicted_patch = model.predict(patch, batch_size = 1)\n",
- " prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = (np.argmax(np.squeeze(predicted_patch), axis = -1)).astype(np.uint8)\n",
- "\n",
- "\n",
- " return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n",
- "\n",
- "\n",
- "def saveResult(save_path, nparray, source_dir_list, prefix=''):\n",
- " for (filename, image) in zip(source_dir_list, nparray):\n",
- " io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), image) # saving as unsigned 8-bit image\n",
- "\n",
- "\n",
- "def convert2Mask(image, threshold):\n",
- " mask = img_as_ubyte(image, force_copy=True)\n",
- " mask[mask > threshold] = 255\n",
- " mask[mask <= threshold] = 0\n",
- " return mask\n",
- "\n",
- "# BMZ model export functions\n",
- "def make_author(author_input_info: str):\n",
- " \"\"\"\n",
- " Create an Author object from a string input.\n",
- "\n",
- " Args:\n",
- " author_input_info: A string containing the author's name and affiliation.\n",
- "\n",
- " Returns:\n",
- " An Author object\n",
- " \"\"\"\n",
- " if author_input_info.strip() == '':\n",
- " return None\n",
- "\n",
- " auth_order = ['name', 'affiliation', 'email', 'orcid', 'github_user']\n",
- " auth_dict = {}\n",
- "\n",
- " auth_info_split = author_input_info.split(',')\n",
- "\n",
- " for i in range(len(auth_info_split)):\n",
- " if auth_info_split[i].strip() == 'None' or auth_info_split[i].strip() == '':\n",
- " continue\n",
- " else:\n",
- " auth_dict[auth_order[i]] = auth_info_split[i].strip()\n",
- "\n",
- " return bioimageio_spec.Author(**auth_dict)\n",
- "\n",
- "def make_maintainer(maintainer_input_info: str):\n",
- " \"\"\"\n",
- " Create an Author object from a string input.\n",
- "\n",
- " Args:\n",
- " author_input_info: A string containing the author's name and affiliation.\n",
- "\n",
- " Returns:\n",
- " An Author object\n",
- " \"\"\"\n",
- " maint_order = [ 'github_user', 'name', 'affiliation', 'email', 'orcid']\n",
- " maint_dict = {}\n",
- "\n",
- " maint_info_split = maintainer_input_info.split(',')\n",
- "\n",
- " for i in range(len(maint_info_split)):\n",
- " if maint_info_split[i].strip() == 'None' or maint_info_split[i].strip() == '':\n",
- " continue\n",
- " else:\n",
- " maint_dict[maint_order[i]] = maint_info_split[i].strip()\n",
- "\n",
- " return bioimageio_spec.Maintainer(**maint_dict)\n",
- "\n",
- "\n",
- "# -------------- Other definitions -----------\n",
- "W = '\\033[0m' # white (normal)\n",
- "R = '\\033[31m' # red\n",
- "prediction_prefix = 'Predicted_'\n",
- "\n",
- "\n",
- "print('-------------------')\n",
- "print('U-Net and dependencies installed.')\n",
- "\n",
- "# Colors for the warning messages\n",
- "class bcolors:\n",
- " WARNING = '\\033[31m'\n",
- "\n",
- "# Check if this is the latest version of the notebook\n",
- "\n",
- "All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n",
- "print('Notebook version: '+Notebook_version)\n",
- "Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n",
- "print('Latest notebook version: '+Latest_Notebook_version)\n",
- "if Notebook_version == Latest_Notebook_version:\n",
- " print(\"This notebook is up-to-date.\")\n",
- "else:\n",
- " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n",
- "\n",
- "\n",
- "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n",
- " class MyFPDF(FPDF, HTMLMixin):\n",
- " pass\n",
- "\n",
- " pdf = MyFPDF()\n",
- " pdf.add_page()\n",
- " pdf.set_right_margin(-1)\n",
- " pdf.set_font(\"Arial\", size = 11, style='B')\n",
- "\n",
- " day = datetime.now()\n",
- " datetime_str = str(day)[0:10]\n",
- "\n",
- " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n",
- " pdf.multi_cell(180, 5, txt = Header, align = 'L')\n",
- " pdf.ln(1)\n",
- "\n",
- " # add another cell\n",
- " if trained:\n",
- " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n",
- " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n",
- " pdf.ln(1)\n",
- "\n",
- " Header_2 = 'Information for your materials and method:'\n",
- " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n",
- " pdf.ln(1)\n",
- "\n",
- " all_packages = ''\n",
- " for requirement in freeze(local_only=True):\n",
- " all_packages = all_packages+requirement+', '\n",
- " #print(all_packages)\n",
- "\n",
- " #Main Packages\n",
- " main_packages = ''\n",
- " version_numbers = []\n",
- " for name in ['tensorflow','numpy','keras']:\n",
- " find_name=all_packages.find(name)\n",
- " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n",
- " #Version numbers only here:\n",
- " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n",
- "\n",
- " try:\n",
- " cuda_version = subprocess.run([\"nvcc\",\"--version\"],stdout=subprocess.PIPE)\n",
- " cuda_version = cuda_version.stdout.decode('utf-8')\n",
- " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n",
- " except:\n",
- " cuda_version = ' - No cuda found - '\n",
- " try:\n",
- " gpu_name = subprocess.run([\"nvidia-smi\"],stdout=subprocess.PIPE)\n",
- " gpu_name = gpu_name.stdout.decode('utf-8')\n",
- " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n",
- " except:\n",
- " gpu_name = ' - No GPU found - '\n",
- " #print(cuda_version[cuda_version.find(', V')+3:-1])\n",
- " #print(gpu_name)\n",
- " loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n",
- " Training_source_list = [f for f in os.listdir(Training_source) if not f.startswith(\".\")]\n",
- " shape = io.imread(Training_source+'/' + Training_source_list[1]).shape\n",
- " dataset_size = len(Training_source_list)\n",
- "\n",
- " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n",
- "\n",
- " if pretrained_model:\n",
- " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n",
- "\n",
- " pdf.set_font('')\n",
- " pdf.set_font_size(10.)\n",
- " pdf.multi_cell(180, 5, txt = text, align='L')\n",
- " pdf.ln(1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 10, style = 'B')\n",
- " pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n",
- " pdf.set_font('')\n",
- " if augmentation:\n",
- " aug_text = 'The dataset was augmented by'\n",
- " if rotation_range != 0:\n",
- " aug_text = aug_text+'\\n- rotation'\n",
- " if horizontal_flip == True or vertical_flip == True:\n",
- " aug_text = aug_text+'\\n- flipping'\n",
- " if zoom_range != 0:\n",
- " aug_text = aug_text+'\\n- random zoom magnification'\n",
- " if horizontal_shift != 0 or vertical_shift != 0:\n",
- " aug_text = aug_text+'\\n- shifting'\n",
- " if shear_range != 0:\n",
- " aug_text = aug_text+'\\n- image shearing'\n",
- " else:\n",
- " aug_text = 'No augmentation was used for training.'\n",
- " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n",
- " pdf.ln(1)\n",
- " pdf.set_font('Arial', size = 11, style = 'B')\n",
- " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font_size(10.)\n",
- " if Use_Default_Advanced_Parameters:\n",
- " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n",
- " pdf.cell(200, 5, txt='The following parameters were used for training:')\n",
- " pdf.ln(3)\n",
- " html = \"\"\"\n",
- " \n",
- " \n",
- " Parameter | \n",
- " Value | \n",
- "
\n",
- " \n",
- " number_of_epochs | \n",
- " {0} | \n",
- "
\n",
- " \n",
- " patch_size | \n",
- " {1} | \n",
- "
\n",
- " \n",
- " batch_size | \n",
- " {2} | \n",
- "
\n",
- " \n",
- " number_of_steps | \n",
- " {3} | \n",
- "
\n",
- " \n",
- " percentage_validation | \n",
- " {4} | \n",
- "
\n",
- " \n",
- " initial_learning_rate | \n",
- " {5} | \n",
- "
\n",
- " \n",
- " pooling_steps | \n",
- " {6} | \n",
- "
\n",
- " \n",
- " min_fraction | \n",
- " {7} | \n",
- "
\n",
- " \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n",
- " pdf.write_html(html)\n",
- "\n",
- " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n",
- " pdf.set_font(\"Arial\", size = 11, style='B')\n",
- " pdf.ln(1)\n",
- " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 10, style = 'B')\n",
- " pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n",
- " pdf.set_font('')\n",
- " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n",
- " pdf.ln(1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 10, style = 'B')\n",
- " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n",
- " pdf.set_font('')\n",
- " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n",
- " pdf.ln(1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 10, style = 'B')\n",
- " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n",
- " pdf.set_font('')\n",
- " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n",
- " pdf.ln(1)\n",
- " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n",
- " pdf.ln(1)\n",
- " exp_size = io.imread(base_path + '/TrainingDataExample_Unet2D.png').shape\n",
- " pdf.image(base_path + '/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n",
- " pdf.ln(1)\n",
- " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n",
- " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n",
- " pdf.ln(1)\n",
- " ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n",
- " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n",
- " # if Use_Data_augmentation:\n",
- " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n",
- " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n",
- " pdf.ln(3)\n",
- " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n",
- " pdf.set_font('Arial', size = 11, style='B')\n",
- " pdf.multi_cell(190, 5, txt=reminder, align='C')\n",
- " pdf.ln(1)\n",
- "\n",
- " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n",
- "\n",
- " print('------------------------------')\n",
- " print('PDF report exported in '+model_path+'/'+model_name+'/')\n",
- "\n",
- "def qc_pdf_export():\n",
- " class MyFPDF(FPDF, HTMLMixin):\n",
- " pass\n",
- "\n",
- " pdf = MyFPDF()\n",
- " pdf.add_page()\n",
- " pdf.set_right_margin(-1)\n",
- " pdf.set_font(\"Arial\", size = 11, style='B')\n",
- "\n",
- " Network = 'Unet 2D'\n",
- "\n",
- " day = datetime.now()\n",
- " datetime_str = str(day)[0:10]\n",
- "\n",
- " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n",
- " pdf.multi_cell(180, 5, txt = Header, align = 'L')\n",
- " pdf.ln(1)\n",
- "\n",
- " all_packages = ''\n",
- " for requirement in freeze(local_only=True):\n",
- " all_packages = all_packages+requirement+', '\n",
- "\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 11, style = 'B')\n",
- " pdf.ln(2)\n",
- " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n",
- " pdf.ln(1)\n",
- " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n",
- " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n",
- " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n",
- " else:\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size=10)\n",
- " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n",
- " pdf.ln(2)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 10, style = 'B')\n",
- " pdf.ln(3)\n",
- " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n",
- " pdf.ln(1)\n",
- " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n",
- " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n",
- " pdf.ln(1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font('Arial', size = 11, style = 'B')\n",
- " pdf.ln(1)\n",
- " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font_size(10.)\n",
- "\n",
- " pdf.ln(1)\n",
- " html = \"\"\"\n",
- " \n",
- " \n",
- " \"\"\"\n",
- " with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n",
- " metrics = csv.reader(csvfile)\n",
- " header = next(metrics)\n",
- " image = header[0]\n",
- " IoU = header[-1]\n",
- " header = \"\"\"\n",
- " \n",
- " {0} | \n",
- " {1} | \n",
- "
\"\"\".format(image,IoU)\n",
- " html = html+header\n",
- " i=0\n",
- " for row in metrics:\n",
- " i+=1\n",
- " image = row[0]\n",
- " IoU = row[-1]\n",
- " cells = \"\"\"\n",
- " \n",
- " {0} | \n",
- " {1} | \n",
- "
\"\"\".format(image,str(round(float(IoU),3)))\n",
- " html = html+cells\n",
- " html = html+\"\"\"
\"\"\"\n",
- "\n",
- " pdf.write_html(html)\n",
- "\n",
- " pdf.ln(1)\n",
- " pdf.set_font('')\n",
- " pdf.set_font_size(10.)\n",
- " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n",
- " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n",
- " pdf.ln(1)\n",
- " ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n",
- " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n",
- " pdf.ln(3)\n",
- " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n",
- "\n",
- " pdf.set_font('Arial', size = 11, style='B')\n",
- " pdf.multi_cell(190, 5, txt=reminder, align='C')\n",
- " pdf.ln(1)\n",
- "\n",
- " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n",
- "\n",
- " print('------------------------------')\n",
- " print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n",
- "\n",
- "# Build requirements file for local run\n",
- "after = [str(m) for m in sys.modules]\n",
- "build_requirements_file(before, after)\n",
- "!pip3 freeze > requirements.txt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "n4yWFoJNnoin"
- },
- "source": [
- "# **2. Complete the Colab session**\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "---\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "DMNHVZfHmbKb"
- },
- "source": [
- "\n",
- "## **2.1. Check for GPU access**\n",
- "---\n",
- "\n",
- "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n",
- "\n",
- "Go to **Runtime -> Change the Runtime type**\n",
- "\n",
- "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n",
- "\n",
- "**Accelerator: GPU** *(Graphics processing unit)*\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "zCvebubeSaGY"
- },
- "outputs": [],
- "source": [
- "#@markdown ##Run this cell to check if you have GPU access\n",
- "\n",
- "!if type nvidia-smi >/dev/null 2>&1; then \\\n",
- " echo \"You have GPU access\"; nvidia-smi; \\\n",
- " else \\\n",
- " echo -e \"You do not have GPU access.\\nDid you change your runtime?\\nIf the runtime setting is correct then Google did not allocate a GPU for your session\\nExpect slow performance. To access GPU try reconnecting later\"; fi"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "sNIVx8_CLolt"
- },
- "source": [
- "## **2.2. Mount your Google Drive**\n",
- "---\n",
- " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n",
- "\n",
- " Play the cell below to mount your Google Drive. Click on **Connect to Google Drive** and a window will pop up. You will need to sign in tour Google Account, follow the steps and click **Continue**. This will give Colab access to the data on the drive.\n",
- "\n",
- " Once this is done, your data are available in the **Files** tab on the top left of notebook."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "01Djr8v-5pPk"
- },
- "outputs": [],
- "source": [
- "#@markdown ##Play the cell to connect your Google Drive to Colab\n",
- "\n",
- "#@markdown * Click on **Connect to Google Drive**.\n",
- "\n",
- "#@markdown * A new window, will pop up. Sign in your Google Account.\n",
- "\n",
- "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\".\n",
- "\n",
- "# mount user's Google Drive to Google Colab.\n",
- "from google.colab import drive\n",
- "drive.mount('/content/gdrive')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dm3eCMYB5d-H"
- },
- "source": [
- "** If you cannot see your files, reactivate your session by connecting to your hosted runtime.**\n",
- "\n",
- "\n",
- " Connect to a hosted runtime. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HLYcZR9gMv42"
- },
- "source": [
- "# **3. Select your parameters and paths**\n",
- "---"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "FQ_QxtSWQ7CL"
- },
- "source": [
- "## **3.1. Setting main training parameters**\n",
- "---\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AuESFimvMv43"
- },
- "source": [
- " **Paths for training data and models**\n",
- "\n",
- "**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (semantic segmentation masks). The mask should be a unique 2D image with values 0, 1, 2, ... each of them corresponding to a semantic definition of the content in the image. The values should be ordered from the lowest to the highest and without missing any value in between (unless it is missing in the image). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n",
- "\n",
- "**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n",
- "\n",
- "**`model_path`**: Enter the path of the folder where you want to save your model.\n",
- "\n",
- "**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n",
- "\n",
- "**`labels`**: The number of different labels that the network needs to learn, which also includes the background. For example: to segment two different kind of objects in an image (cats and dogs), labels = 3 (2 labels for the two kinds and one more label for the background).\n",
- "\n",
- "\n",
- " **Select training parameters**\n",
- "\n",
- "**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n",
- "\n",
- "**Advanced parameters - experienced users only**\n",
- "\n",
- "**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n",
- "\n",
- "**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n",
- "\n",
- " **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n",
- "\n",
- "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10**\n",
- "\n",
- "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n",
- "\n",
- "**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n",
- "\n",
- "**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. It should be between 0 and 1.**Default value: 0.02** (2%)\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "ewpNJ_I0Mv47"
- },
- "outputs": [],
- "source": [
- "# ------------- Initial user input ------------\n",
- "#@markdown ###Path to training images:\n",
- "Training_source = '' #@param {type:\"string\"}\n",
- "Training_target = '' #@param {type:\"string\"}\n",
- "\n",
- "model_name = '' #@param {type:\"string\"}\n",
- "model_path = '' #@param {type:\"string\"}\n",
- "\n",
- "labels = 3 #@param {type:\"number\"}\n",
- "\n",
- "#@markdown ###Training parameters:\n",
- "#@markdown Number of epochs\n",
- "number_of_epochs = 10#@param {type:\"number\"}\n",
- "\n",
- "#@markdown ###Advanced parameters:\n",
- "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n",
- "\n",
- "#@markdown ###If not, please input:\n",
- "batch_size = 5#@param {type:\"integer\"}\n",
- "number_of_steps = 0#@param {type:\"number\"}\n",
- "pooling_steps = 3 #@param [1,2,3,4]{type:\"raw\"}\n",
- "percentage_validation = 10#@param{type:\"number\"}\n",
- "initial_learning_rate = 0.0001 #@param {type:\"number\"}\n",
- "\n",
- "patch_width = 320#@param{type:\"number\"}\n",
- "patch_height = 320#@param{type:\"number\"}\n",
- "min_fraction = 0.05#@param{type:\"number\"}\n",
- "\n",
- "\n",
- "# ------------- Initialising folder, variables and failsafes ------------\n",
- "# Create the folders where to save the model and the QC\n",
- "full_model_path = os.path.join(model_path, model_name)\n",
- "if os.path.exists(full_model_path):\n",
- " print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n",
- "\n",
- "if (Use_Default_Advanced_Parameters):\n",
- " print(\"Default advanced parameters enabled\")\n",
- " batch_size = 4\n",
- " pooling_steps = 2\n",
- " percentage_validation = 10\n",
- " initial_learning_rate = 0.0003\n",
- " patch_width, patch_height = estimatePatchSize(Training_source)\n",
- " min_fraction = 0.02\n",
- "\n",
- "\n",
- "#The create_patches function will create the two folders below\n",
- "# Patch_source = '/content/img_patches'\n",
- "# Patch_target = '/content/mask_patches'\n",
- "print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n",
- "\n",
- "#Create patches\n",
- "print('Creating patches...')\n",
- "Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n",
- "\n",
- "number_of_training_dataset = len(os.listdir(Patch_source))\n",
- "print('Total number of valid patches: '+str(number_of_training_dataset))\n",
- "\n",
- "if Use_Default_Advanced_Parameters or number_of_steps == 0:\n",
- " number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n",
- "print('Number of steps: '+str(number_of_steps))\n",
- "\n",
- "# Calculate the number of steps to use for validation\n",
- "validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n",
- "validatio_split = percentage_validation/100\n",
- "\n",
- "# Here we disable pre-trained model by default (in case the next cell is not ran)\n",
- "Use_pretrained_model = False\n",
- "# Here we disable data augmentation by default (in case the cell is not ran)\n",
- "Use_Data_augmentation = False\n",
- "# Build the default dict for the ImageDataGenerator\n",
- "data_gen_args = dict(width_shift_range = 0.,\n",
- " height_shift_range = 0.,\n",
- " rotation_range = 0., #90\n",
- " zoom_range = 0.,\n",
- " shear_range = 0.,\n",
- " horizontal_flip = False,\n",
- " vertical_flip = False,\n",
- " validation_split = percentage_validation/100,\n",
- " fill_mode = 'reflect')\n",
- "\n",
- "# ------------- Display ------------\n",
- "\n",
- "#if not os.path.exists('/content/img_patches/'):\n",
- "random_choice = random.choice(os.listdir(Patch_source))\n",
- "x = io.imread(os.path.join(Patch_source, random_choice))\n",
- "\n",
- "#os.chdir(Training_target)\n",
- "y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n",
- "\n",
- "f=plt.figure(figsize=(16,8))\n",
- "plt.subplot(1,2,1)\n",
- "plt.imshow(x, interpolation='nearest',cmap='gray')\n",
- "plt.title('Training image patch')\n",
- "plt.axis('off');\n",
- "\n",
- "plt.subplot(1,2,2)\n",
- "plt.imshow(y, interpolation='nearest',cmap='gray')\n",
- "plt.title('Training mask patch')\n",
- "plt.axis('off');\n",
- "\n",
- "plt.savefig(base_path + '/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "w_jCy7xOx2g3"
- },
- "source": [
- "## **3.2. Data augmentation**\n",
- "\n",
- "---\n",
- "\n",
- " Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n",
- "\n",
- " The augmentation options below are to be used as follows:\n",
- "\n",
- "* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n",
- "* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n",
- "* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n",
- "* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n",
- "* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "DMqWq5-AxnFU"
- },
- "outputs": [],
- "source": [
- "#@markdown ##**Augmentation options**\n",
- "\n",
- "Use_Data_augmentation = True #@param {type:\"boolean\"}\n",
- "Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n",
- "\n",
- "if Use_Data_augmentation:\n",
- " if Use_Default_Augmentation_Parameters:\n",
- " horizontal_shift = 10\n",
- " vertical_shift = 20\n",
- " zoom_range = 10\n",
- " shear_range = 10\n",
- " horizontal_flip = True\n",
- " vertical_flip = True\n",
- " rotation_range = 180\n",
- "#@markdown ###If you are not using the default settings, please provide the values below:\n",
- "\n",
- "#@markdown ###**Image shift, zoom, shear and flip (%)**\n",
- " else:\n",
- " horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n",
- " vertical_shift = 13 #@param {type:\"slider\", min:0, max:100, step:1}\n",
- " zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n",
- " shear_range = 14 #@param {type:\"slider\", min:0, max:100, step:1}\n",
- " horizontal_flip = True #@param {type:\"boolean\"}\n",
- " vertical_flip = True #@param {type:\"boolean\"}\n",
- "\n",
- "#@markdown ###**Rotate image within angle range (degrees):**\n",
- " rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n",
- "\n",
- "#given behind the # are the default values for each parameter.\n",
- "\n",
- "else:\n",
- " horizontal_shift = 0\n",
- " vertical_shift = 0\n",
- " zoom_range = 0\n",
- " shear_range = 0\n",
- " horizontal_flip = False\n",
- " vertical_flip = False\n",
- " rotation_range = 0\n",
- "\n",
- "\n",
- "# Build the dict for the ImageDataGenerator\n",
- "data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n",
- " height_shift_range = vertical_shift/100.,\n",
- " rotation_range = rotation_range, #90\n",
- " zoom_range = zoom_range/100.,\n",
- " shear_range = shear_range/100.,\n",
- " horizontal_flip = horizontal_flip,\n",
- " vertical_flip = vertical_flip,\n",
- " validation_split = percentage_validation/100,\n",
- " fill_mode = 'reflect')\n",
- "\n",
- "\n",
- "\n",
- "# ------------- Display ------------\n",
- "dir_augmented_data_imgs = base_path + \"/augment_img\"\n",
- "dir_augmented_data_masks = base_path + \"/augment_mask\"\n",
- "random_choice = random.choice(os.listdir(Patch_source))\n",
- "orig_img = load_img(os.path.join(Patch_source,random_choice))\n",
- "orig_mask = load_img(os.path.join(Patch_target,random_choice))\n",
- "\n",
- "augment_view = ImageDataGenerator(**data_gen_args)\n",
- "\n",
- "if Use_Data_augmentation:\n",
- " print(\"Parameters enabled\")\n",
- " print(\"Here is what a subset of your augmentations looks like:\")\n",
- " save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n",
- " save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n",
- "\n",
- " fig = plt.figure(figsize=(15, 7))\n",
- " fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n",
- "\n",
- "\n",
- " ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[])\n",
- " new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n",
- " ax.imshow(new_img)\n",
- " ax.set_title('Original Image')\n",
- " i = 2\n",
- " for imgnm in os.listdir(dir_augmented_data_imgs):\n",
- " ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[])\n",
- " img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n",
- " ax.imshow(img)\n",
- " i += 1\n",
- "\n",
- " ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[])\n",
- " new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n",
- " ax.imshow(new_mask)\n",
- " ax.set_title('Original Mask')\n",
- " j=2\n",
- " for imgnm in os.listdir(dir_augmented_data_masks):\n",
- " ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[])\n",
- " mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n",
- " ax.imshow(mask)\n",
- " j += 1\n",
- " plt.show()\n",
- "\n",
- "else:\n",
- " print(\"No augmentation will be used\")\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3L9zSGtORKYI"
- },
- "source": [
- "\n",
- "## **3.3. Using weights from a pre-trained model as initial weights**\n",
- "---\n",
- " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**.\n",
- "\n",
- " **You do not need to run this section if you want to train a network from scratch**.\n",
- "\n",
- " This option allows you to use pre-trained models from the [BioImage Model Zoo](https://bioimage.io/#/) and fine-tune them to analyse new data. Choose `bioimageio_model` and provide the ID in `bioimageio_model_id` (e.g., \"placid-llama\" or \"10.5281/zenodo.5817052\").\n",
- "\n",
- " This option also allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. Choose `Model_from_file` and provide the `pretrained_model_path`.\n",
- "\n",
- " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "9vC2n-HeLdiJ"
- },
- "outputs": [],
- "source": [
- "# @markdown ##Loading weights from a pre-trained network\n",
- "from bioimageio.core import load_description\n",
- "from bioimageio.spec.utils import download\n",
- "\n",
- "Use_pretrained_model = False #@param {type:\"boolean\"}\n",
- "pretrained_model_choice = \"BioImage Model Zoo\" #@param [\"Model_from_file\", \"BioImage Model Zoo\"]\n",
- "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n",
- "\n",
- "\n",
- "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n",
- "pretrained_model_path = \"\" #@param {type:\"string\"}\n",
- "\n",
- "#@markdown ###If you chose \"BioImage Model Zoo\", please provide the path or the URL to the model:\n",
- "bioimageio_model_id = \"\" #@param {type:\"string\"}\n",
- "\n",
- "# --------------------- Check if we load a previously trained model ------------------------\n",
- "if Use_pretrained_model:\n",
- "\n",
- "# --------------------- Load the model from the choosen path ------------------------\n",
- " if pretrained_model_choice == \"Model_from_file\":\n",
- " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n",
- " qc_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')\n",
- " elif pretrained_model_choice == \"BioImage Model Zoo\":\n",
- " model_spec = load_description(bioimageio_model_id)\n",
- " if \"keras_hdf5\" not in model_spec.weights.model_fields_set:\n",
- " print(\"Invalid bioimageio model\")\n",
- " h5_file_path = \"no-model\"\n",
- " qc_path = \"no-qc\"\n",
- " else:\n",
- " h5_file_path = str(download(model_spec.weights.keras_hdf5.source).path)\n",
- " try:\n",
- " attachments = model_spec.attachments.files\n",
- " qc_path = str(download([fname for fname in attachments if str(fname).endswith(\"training_evaluation.csv\")][0]).path)\n",
- " # qc_path = os.path.join(base_path + \"//bioimageio_pretrained_model\", qc_path)\n",
- " except Exception:\n",
- " qc_path = \"no-qc\"\n",
- "\n",
- "# --------------------- Check the model exist ------------------------\n",
- "\n",
- " if not os.path.exists(h5_file_path):\n",
- " # If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled,\n",
- " print(R+'WARNING: pretrained model does not exist')\n",
- " Use_pretrained_model = False\n",
- " else:\n",
- " # If the model path contains a pretrain model, we load the training rate\n",
- "\n",
- " if os.path.exists(qc_path):\n",
- " #Here we check if the learning rate can be loaded from the quality control folder\n",
- " # if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n",
- "\n",
- " # with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n",
- " with open(qc_path,'r') as csvfile:\n",
- " csvRead = pd.read_csv(csvfile, sep=',')\n",
- " #print(csvRead)\n",
- "\n",
- " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n",
- " print(\"A 'learning rate' attribute was found on provided pre-trained models.\")\n",
- " #find the last learning rate\n",
- " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n",
- " #Find the learning rate corresponding to the lowest validation loss\n",
- " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n",
- " #print(min_val_loss)\n",
- " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n",
- "\n",
- " if Weights_choice == \"last\":\n",
- " print(f'You will be loading \\033[1mlast\\033[0m learning rate: {lastLearningRate}')\n",
- " elif Weights_choice == \"best\":\n",
- " print(f'You will be loading the learning rate of \\033[1mbest\\033[0m validation loss: {bestLearningRate}')\n",
- " else:\n",
- " #if the column does not exist, then initial learning rate is used instead\n",
- " print(f\"{bcolors.WARNING}WARNING: The learning rate cannot be identified from the pretrained network{W}\")\n",
- " print(f\"{bcolors.WARNING}Default learning rate of {initial_learning_rate} will be used instead{W}\")\n",
- "\n",
- " bestLearningRate = initial_learning_rate\n",
- " lastLearningRate = initial_learning_rate\n",
- " else:\n",
- " #Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n",
- " print(f\"{bcolors.WARNING}Sorry, 'training_evaluation.csv' does not exists or was not correctly loaded.{W}\")\n",
- " print(f\"{bcolors.WARNING}Default learning rate of {initial_learning_rate} will be used instead{W}\")\n",
- " bestLearningRate = initial_learning_rate\n",
- " lastLearningRate = initial_learning_rate\n",
- "\n",
- "\n",
- "# Display info about the pretrained model to be loaded (or not)\n",
- "if Use_pretrained_model:\n",
- " print('-'*50)\n",
- " print(f'Weights found in: {h5_file_path}')\n",
- " print('They will be loaded prior to training.')\n",
- "\n",
- "else:\n",
- " print(R+'No pretrained network will be used.')\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "MCGklf1vZf2M"
- },
- "source": [
- "\n",
- "# **4. Train the network**\n",
- "---\n",
- "#### **Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1KYOuygETJkT"
- },
- "source": [
- "## **4.1. Prepare the training data and model for training**\n",
- "---\n",
- "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "lIUAOJ_LMv5E"
- },
- "outputs": [],
- "source": [
- "#@markdown ##Play this cell to prepare the model for training\n",
- "\n",
- "\n",
- "# ------------------ Set the generators, model and logger ------------------\n",
- "# This will take the image size and set that as a patch size (arguable...)\n",
- "# Read image size (without actuall reading the data)\n",
- "\n",
- "(train_datagen, validation_datagen) = prepareGenerators(Patch_source,\n",
- " Patch_target,\n",
- " data_gen_args,\n",
- " batch_size,\n",
- " target_size = (patch_width, patch_height),\n",
- " validatio_split = validatio_split)\n",
- "\n",
- "\n",
- "# This modelcheckpoint will only save the best model from the validation loss point of view\n",
- "model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'),\n",
- " monitor='val_loss',verbose=1, save_best_only=True)\n",
- "\n",
- "# --------------------- Using pretrained model ------------------------\n",
- "#Here we ensure that the learning rate set correctly when using pre-trained models\n",
- "if Use_pretrained_model:\n",
- " if Weights_choice == \"last\":\n",
- " initial_learning_rate = lastLearningRate\n",
- "\n",
- " if Weights_choice == \"best\":\n",
- " initial_learning_rate = bestLearningRate\n",
- "else:\n",
- " h5_file_path = None\n",
- "\n",
- "# --------------------- ---------------------- ------------------------\n",
- "\n",
- "# --------------------- Reduce learning rate on plateau ------------------------\n",
- "\n",
- "reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, verbose=1,\n",
- " mode='auto', patience=20, min_lr=0)\n",
- "# --------------------- ---------------------- ------------------------\n",
- "\n",
- "# Define the model\n",
- "model = unet(input_size = (patch_width,patch_height,1),\n",
- " pooling_steps = pooling_steps,\n",
- " learning_rate = initial_learning_rate,\n",
- " labels = labels)\n",
- "\n",
- "# --------------------- Using pretrained model ------------------------\n",
- "# Load the pretrained weights\n",
- "if Use_pretrained_model:\n",
- " try:\n",
- " print(\"Weights correctly loaded.\")\n",
- " model.load_weights(h5_file_path)\n",
- " except:\n",
- " print(f\"{bcolors.WARNING}The pretrained model could not be loaded as the configuration of the network is different.\")\n",
- " print(\"Please, read the model specifications and check the parameters in Section 3.1\")\n",
- " print(f\"It might probably be the pooling steps attribute, please take a look to it.{W}\")\n",
- "\n",
- "config_model= model.optimizer.get_config()\n",
- "print(\"Configuration of model's optimizer:\")\n",
- "for k,v in config_model.items():\n",
- " print(f\"{k} : {v}\")\n",
- "\n",
- "\n",
- "# ------------------ Failsafes ------------------\n",
- "if os.path.exists(full_model_path):\n",
- " print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n",
- " shutil.rmtree(full_model_path)\n",
- "\n",
- "os.makedirs(full_model_path)\n",
- "os.makedirs(os.path.join(full_model_path,'Quality Control'))\n",
- "\n",
- "\n",
- "# ------------------ Display ------------------\n",
- "print('---------------------------- Main training parameters ----------------------------')\n",
- "print('Number of epochs: '+str(number_of_epochs))\n",
- "print('Batch size: '+str(batch_size))\n",
- "print('Number of training dataset: '+str(number_of_training_dataset))\n",
- "print('Number of training steps: '+str(number_of_steps))\n",
- "print('Number of validation steps: '+str(validation_steps))\n",
- "print('---------------------------- ------------------------ ----------------------------')\n",
- "\n",
- "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0Dfn8ZsEMv5d"
- },
- "source": [
- "## **4.2. Start Training**\n",
- "---\n",
- "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n",
- "\n",
- "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n",
- "\n",
- "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "iwNmp1PUzRDQ",
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "#@markdown ##Start training\n",
- "\n",
- "start = time.time()\n",
- "\n",
- "\n",
- "\n",
- "random_choice = random.choice(os.listdir(Patch_source))\n",
- "x = io.imread(os.path.join(Patch_source, random_choice))\n",
- "sample_batch = np.expand_dims(normalizePercentile(x), axis = [0, -1])\n",
- "sample_img = SampleImageCallback(model, sample_batch, os.path.join(model_path, model_name))\n",
- "\n",
- "history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps,\n",
- " epochs = number_of_epochs,\n",
- " callbacks=[model_checkpoint, reduce_lr, sample_img],\n",
- " validation_data = validation_datagen,\n",
- " validation_steps = 3, shuffle=True, verbose=1)\n",
- "\n",
- "# Save the last model\n",
- "model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n",
- "\n",
- "\n",
- "# convert the history.history dict to a pandas DataFrame:\n",
- "lossData = pd.DataFrame(history.history)\n",
- "\n",
- "# The training evaluation.csv is saved (overwrites the Files if needed).\n",
- "lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n",
- "with open(lossDataCSVpath, 'w') as f:\n",
- " writer = csv.writer(f)\n",
- " writer.writerow(['loss','val_loss', 'learning rate'])\n",
- " for i in range(len(history.history['loss'])):\n",
- " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n",
- "\n",
- "\n",
- "\n",
- "# Displaying the time elapsed for training\n",
- "print(\"------------------------------------------\")\n",
- "dt = time.time() - start\n",
- "mins, sec = divmod(dt, 60)\n",
- "hour, mins = divmod(mins, 60)\n",
- "print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n",
- "print(\"------------------------------------------\")\n",
- "\n",
- "#Create a pdf document with training summary\n",
- "\n",
- "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_0Hynw3-xHp1"
- },
- "source": [
- "# **5. Evaluate your model**\n",
- "---\n",
- "\n",
- "This section allows the user to perform important quality checks on the validity and generalisability of the trained model.\n",
- "\n",
- "**We highly recommend to perform quality control on all newly trained models.**\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "eAJzMwPA6tlH"
- },
- "outputs": [],
- "source": [
- "#@markdown ###Do you want to assess the model you just trained ?\n",
- "\n",
- "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n",
- "\n",
- "#@markdown ###If not, please provide the path to the model folder:\n",
- "\n",
- "QC_model_folder = \"\" #@param {type:\"string\"}\n",
- "\n",
- "#Here we define the loaded model name and path\n",
- "QC_model_name = os.path.basename(QC_model_folder)\n",
- "QC_model_path = os.path.dirname(QC_model_folder)\n",
- "\n",
- "\n",
- "if (Use_the_current_trained_model):\n",
- " print(\"Using current trained network\")\n",
- " QC_model_name = model_name\n",
- " QC_model_path = model_path\n",
- "else:\n",
- " # These are used in section 6\n",
- " model_name = QC_model_name\n",
- " model_path = QC_model_path\n",
- "\n",
- "full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n",
- "if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n",
- " print(\"The \"+QC_model_name+\" network will be evaluated\")\n",
- "else:\n",
- " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n",
- " print('Please make sure you provide a valid model path and model name before proceeding further.')\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dhJROwlAMv5o"
- },
- "source": [
- "## **5.1. Inspection of the loss function**\n",
- "---\n",
- "\n",
- "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n",
- "\n",
- "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n",
- "\n",
- "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n",
- "\n",
- "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n",
- "\n",
- "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "vMzSP50kMv5p"
- },
- "outputs": [],
- "source": [
- "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n",
- "\n",
- "epochNumber = []\n",
- "lossDataFromCSV = []\n",
- "vallossDataFromCSV = []\n",
- "\n",
- "with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n",
- " csvRead = csv.reader(csvfile, delimiter=',')\n",
- " next(csvRead)\n",
- " for row in csvRead:\n",
- " lossDataFromCSV.append(float(row[0]))\n",
- " vallossDataFromCSV.append(float(row[1]))\n",
- "\n",
- "epochNumber = range(len(lossDataFromCSV))\n",
- "\n",
- "plt.figure(figsize=(15,10))\n",
- "\n",
- "plt.subplot(2,1,1)\n",
- "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n",
- "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n",
- "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n",
- "plt.ylabel('Loss')\n",
- "plt.xlabel('Epoch number')\n",
- "plt.legend()\n",
- "\n",
- "plt.subplot(2,1,2)\n",
- "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n",
- "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n",
- "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n",
- "plt.ylabel('Loss')\n",
- "plt.xlabel('Epoch number')\n",
- "plt.legend()\n",
- "plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n",
- "plt.show()\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "X5_92nL2xdP6"
- },
- "source": [
- "## **5.2. Error mapping and quality metrics estimation**\n",
- "---\n",
- "This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n",
- "\n",
- "The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei.\n",
- "\n",
- "The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n",
- "\n",
- " The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n",
- "\n",
- "### **Thresholds for image masks**\n",
- "\n",
- " Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "w90MdriMxhjD"
- },
- "outputs": [],
- "source": [
- "# ------------- User input ------------\n",
- "#@markdown ##Choose the folders that contain your Quality Control dataset\n",
- "Source_QC_folder = \"\" #@param{type:\"string\"}\n",
- "Target_QC_folder = \"\" #@param{type:\"string\"}\n",
- "\n",
- "\n",
- "# ------------- Initialise folders ------------\n",
- "# Create a quality control/Prediction Folder\n",
- "prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n",
- "if os.path.exists(prediction_QC_folder):\n",
- " shutil.rmtree(prediction_QC_folder)\n",
- "\n",
- "os.makedirs(prediction_QC_folder)\n",
- "\n",
- "\n",
- "# ------------- Prepare the model and run predictions ------------\n",
- "\n",
- "# Load the model\n",
- "unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n",
- "labels = unet.output_shape[-1]\n",
- "Input_size = unet.layers[0].output_shape[0][1:3]\n",
- "print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n",
- "\n",
- "# Create a list of sources\n",
- "source_dir_list = os.listdir(Source_QC_folder)\n",
- "number_of_dataset = len(source_dir_list)\n",
- "print('Number of dataset found in the folder: '+str(number_of_dataset))\n",
- "\n",
- "predictions = []\n",
- "for i in tqdm(range(number_of_dataset)):\n",
- " predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n",
- "\n",
- "\n",
- "# Save the results in the folder along with the masks according to the set threshold\n",
- "saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix)\n",
- "\n",
- "#-----------------------------Calculate Metrics----------------------------------------#\n",
- "\n",
- "# Here we start testing the differences between GT and predicted masks\n",
- "\n",
- "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n",
- " writer = csv.writer(file, delimiter=\",\")\n",
- " stats_columns = [\"image\"]\n",
- "\n",
- " for l in range(labels):\n",
- " stats_columns.append(\"Prediction v. GT IoU label = {}\".format(l))\n",
- " stats_columns.append(\"Prediction v. GT averaged IoU\")\n",
- " writer.writerow(stats_columns)\n",
- " # Initialise the lists\n",
- " filename_list = []\n",
- " iou_score_list = []\n",
- " for filename in os.listdir(Source_QC_folder):\n",
- " if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n",
- " print('Running QC on: '+filename)\n",
- " test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n",
- " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n",
- " test_prediction = io.imread(os.path.join(prediction_QC_folder, prediction_prefix + filename))\n",
- "\n",
- " iou_labels = [filename]\n",
- " iou_score = 0.\n",
- " for l in range(labels):\n",
- " aux_gt = (test_ground_truth_image==l).astype(np.uint8)\n",
- " aux_pred = (test_prediction==l).astype(np.uint8)\n",
- " intersection = np.logical_and(aux_gt, aux_pred)\n",
- " union = np.logical_or(aux_gt, aux_pred)\n",
- "\n",
- " iou_labels.append(str(np.sum(intersection) / np.sum(union)))\n",
- " iou_score += np.sum(intersection) / np.sum(union)\n",
- " filename_list.append(filename)\n",
- " iou_score_list.append(iou_score/labels)\n",
- " iou_labels.append(str(iou_score/labels))\n",
- " writer.writerow(iou_labels)\n",
- " file.close()\n",
- "\n",
- "## Create a display of the results\n",
- "\n",
- "# Table with metrics as dataframe output\n",
- "pdResults = pd.DataFrame(index = filename_list)\n",
- "pdResults[\"IoU\"] = iou_score_list\n",
- "\n",
- "# ------------- For display ------------\n",
- "print('--------------------------------------------------------------')\n",
- "@interact\n",
- "def show_QC_results(file=os.listdir(Source_QC_folder)):\n",
- "\n",
- " plt.figure(figsize=(25,5))\n",
- " #Input\n",
- " plt.subplot(1,4,1)\n",
- " plt.axis('off')\n",
- " plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n",
- " plt.title('Input')\n",
- "\n",
- " #Ground-truth\n",
- " plt.subplot(1,4,2)\n",
- " plt.axis('off')\n",
- " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n",
- " plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n",
- " plt.title('Ground Truth')\n",
- "\n",
- " #Prediction\n",
- " plt.subplot(1,4,3)\n",
- " plt.axis('off')\n",
- " test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n",
- " plt.imshow(test_prediction, aspect='equal', cmap='Purples')\n",
- " plt.title('Prediction')\n",
- "\n",
- " #Overlay\n",
- " plt.subplot(1,4,4)\n",
- " plt.axis('off')\n",
- " plt.imshow(test_ground_truth_image, cmap='Greens')\n",
- " plt.imshow(test_prediction, alpha=0.5, cmap='Purples')\n",
- " metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ')'\n",
- " plt.title(metrics_title)\n",
- " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n",
- "\n",
- "qc_pdf_export()\n",
- "\n",
- "pdResults.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Uyr7jZWuPJKG"
- },
- "source": [
- "## **5.3. Export your model into the BioImage Model Zoo format**\n",
- "---\n",
- "This section exports the model into the [BioImage Model Zoo](https://bioimage.io/#/) format so it can be used directly with deepImageJ or Ilastik. The new files will be stored in the model folder specified at the beginning of Section 5.\n",
- "\n",
- "Once the cell is executed, you will find a new zip file with the name specified in `trained_model_name.bioimage.io.model`.\n",
- "\n",
- "To use it with deepImageJ, download it and install it suing DeepImageJ Install Model > Install from a local file.\n",
- "\n",
- "To try the model in ImageJ, go to Plugins > DeepImageJ > DeepImageJ Run, choose this model from the list and click on Test Model.\n",
- "\n",
- "The exported model contains an additional ImageJ macro (`Contours2InstanceSegmentation.ijm`) to obtain a unique 2D image with the different labels and also identify each independent object in the image.\n",
- "\n",
- " More information at https://deepimagej.github.io/deepimagej/"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "2NHJam91PT3U"
- },
- "outputs": [],
- "source": [
- "# ------------- User input ------------\n",
- "# information about the model\n",
- "#@markdown ##Insert the information to document your model:\n",
- "Trained_model_name = \"\" #@param {type:\"string\"}\n",
- "Trained_model_description = \"\" #@param {type:\"string\"}\n",
- "\n",
- "#@markdown ###Author(s) - insert information separated by commas:\n",
- "Trained_model_author_1 = \"Author 1 name, *Author 1 affiliation, *Author 1 email, *Author 1 ORCID, *Author 1 Github User\" #@param {type:\"string\"}\n",
- "Trained_model_author_2 = \"Author 2 name, *Author 2 affiliation, *Author 2 email, *Author 2 ORCID, *Author 2 Github User\" #@param {type:\"string\"}\n",
- "\n",
- "# @markdown ###Model Packager:\n",
- "packager_same_as_author = True #@param {type:\"boolean\"}\n",
- "#@markdown - If not, please, provide the following information:\n",
- "Trained_model_packager = \"Packager name, *Packager affiliation, *Packager email, *Packager ORCID, *Packager Github User\" #@param {type:\"string\"}\n",
- "\n",
- "# @markdown ###Model Maintainer:\n",
- "maintainer_same_as_author = True #@param {type:\"boolean\"}\n",
- "#@markdown - If not, please, provide the following information:\n",
- "Trained_model_maintainer = \"Maintainer Github User, *Maintainer name, *Maintainer affiliation, *Maintainer email, *Maintainer ORCID\" #@param {type:\"string\"}\n",
- "\n",
- "# @markdown ###License:\n",
- "Trained_model_license = 'CC-BY-4.0' #@param {type:\"string\"}\n",
- "\n",
- "Trained_model_references = [\"Falk et al. Nature Methods 2019\", \"Ronneberger et al. arXiv in 2015\", \"Lucas von Chamier et al. biorXiv 2020\"]\n",
- "Trained_model_DOI = [\"https://doi.org/10.1038/s41592-018-0261-2\",\"https://doi.org/10.1007/978-3-319-24574-4_28\", \"https://doi.org/10.1101/2020.03.20.000133\"]\n",
- "\n",
- "# Training data\n",
- "# ---------------------------------------\n",
- "#@markdown ##Include information about training data (optional):\n",
- "include_training_data = False #@param {type: \"boolean\"}\n",
- "#@markdown ### - If it is published in the BioImage Model Zoo, please, provide the ID\n",
- "data_from_bioimage_model_zoo = False #@param {type: \"boolean\"}\n",
- "training_data_ID = ''#@param {type:\"string\"}\n",
- "#@markdown ### - If not, please provide the URL to the data and a short description to be added to the README.md file\n",
- "training_data_source = ''#@param {type:\"string\"}\n",
- "training_data_description = ''#@param {type:\"string\"}\n",
- "\n",
- "# Add input image information\n",
- "# ---------------------------------------\n",
- "#@markdown ##Indicate the minimum x/y size of the image (in pixels) and step size (in pixels) to be used for block/tiling:\n",
- "# information about the example image\n",
- "min_size = 64 #@param {type:\"number\"}\n",
- "step_size = 16 #@param {type:\"number\"}\n",
- "#@markdown ##Do you want to choose the example image?\n",
- "default_example_image = True #@param {type:\"boolean\"}\n",
- "#@markdown ###If not, please input:\n",
- "fileID = \"\" #@param {type:\"string\"}\n",
- "if default_example_image:\n",
- " fileID = os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0])\n",
- "\n",
- "# Load the model and process the example image\n",
- "# ---------------------------------------\n",
- "# Load the model\n",
- "model = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'),\n",
- " custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n",
- "\n",
- "# ------------- Execute bioimage model zoo configuration ------------\n",
- "# Create a model without compilation so it can be used in any other environment.\n",
- "# remove the custom loss function from the model, so that it can be used outside of this notebook\n",
- "unet = Model(model.input, model.output)\n",
- "weight_path = os.path.join(full_QC_model_path, 'keras_weights.hdf5')\n",
- "unet.save(weight_path)\n",
- "\n",
- "# training data source\n",
- "if data_from_bioimage_model_zoo:\n",
- " training_data = {'id' : training_data_ID}\n",
- "else:\n",
- " training_data = None\n",
- "\n",
- "# create the author/maintainer/packager spec input\n",
- "author_1_spec = make_author(Trained_model_author_1)\n",
- "authors = [author_1_spec]\n",
- "\n",
- "# check if author 2 was filled\n",
- "if 'Author 2 name' not in Trained_model_author_2:\n",
- " author_2_spec = make_author(Trained_model_author_2)\n",
- " authors.append(author_2_spec)\n",
- "\n",
- "if packager_same_as_author:\n",
- " packager_spec = author_1_spec\n",
- "else:\n",
- " packager_spec = make_author(Trained_model_packager)\n",
- "\n",
- "if maintainer_same_as_author:\n",
- " if author_1_spec.github_user != None:\n",
- " maintainer_from_author = [str(author_1_spec.github_user), str(author_1_spec.name), str(author_1_spec.affiliation), str(author_1_spec.email), str(author_1_spec.orcid)]\n",
- " maintainer_str = ', '.join(maintainer_from_author)\n",
- " maintainer_spec = make_maintainer(maintainer_str)\n",
- " else:\n",
- " print('Please, provide the author GitHub username in the author information')\n",
- "else:\n",
- " maintainer_spec = make_maintainer(Trained_model_maintainer)\n",
- "\n",
- "\n",
- "# I would recommend using CCBY-4 as licence\n",
- "license = Trained_model_license\n",
- "\n",
- "# where to save the model\n",
- "output_root = os.path.join(full_QC_model_path, Trained_model_name + '.bioimage.io.model')\n",
- "os.makedirs(output_root, exist_ok=True)\n",
- "output_path = os.path.join(output_root, f\"{Trained_model_name}.zip\")\n",
- "\n",
- "# create a markdown readme with information\n",
- "documentation_path = os.path.join(output_root, \"README.md\")\n",
- "with open(documentation_path, \"w\") as f:\n",
- " f.write(\"Visit https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki \\n\\n This was an automatically generated README.md. \\n\\n\")\n",
- "\n",
- "# create the citation input spec\n",
- "assert len(Trained_model_DOI) == len(Trained_model_references)\n",
- "citations = [{'text': text, 'doi': doi.replace('https://doi.org/', '')} for text, doi in zip(Trained_model_references, Trained_model_DOI)]\n",
- "citation_spec = [bioimageio_spec.CiteEntry(**c) for c in citations]\n",
- "\n",
- "# create the training data\n",
- "if include_training_data:\n",
- " if data_from_bioimage_model_zoo:\n",
- " training_data = {\"id\": training_data_ID}\n",
- " else:\n",
- " with open(documentation_path, \"a\") as f:\n",
- " f.write(f'Training data: {training_data_source} \\n\\n and description: {training_data_description} \\n\\n')\n",
- " training_data = None\n",
- "else:\n",
- " training_data = None\n",
- "\n",
- "\n",
- "# load the input image, crop it if necessary, and save as numpy file\n",
- "# The crop will be centered to get an image with some content.\n",
- "input_img = io.imread(fileID, as_gray = True).astype(np.float32)\n",
- "assert input_img.ndim == 2,'Example input image is not a 2D grayscale image. Please, provide a 2D grayscale image.'\n",
- "\n",
- "# batch should never be constrained\n",
- "shape = [sh for sh in unet.input.shape]\n",
- "assert shape[0] is None\n",
- "shape[0] = 1 # batch is set to 1 for bioimage.io\n",
- "assert all(sh is not None for sh in shape) # make sure all other shapes are fixed\n",
- "\n",
- "test_img = input_img\n",
- "\n",
- "x_size = int(test_img.shape[0]/2)\n",
- "x_size = x_size-int(shape[1]/2)\n",
- "\n",
- "y_size = int(test_img.shape[1]/2)\n",
- "y_size = y_size-int(shape[2]/2)\n",
- "assert test_img.ndim == 2\n",
- "test_img = test_img[x_size : x_size + shape[1],\n",
- " y_size : y_size + shape[2]]\n",
- "\n",
- "# Save the test image\n",
- "test_input_path = os.path.join(output_root, \"test_input.npy\")\n",
- "np.save(test_input_path, test_img[None, ..., None])\n",
- "\n",
- "# run prediction on the input image and save the result as expected output\n",
- "test_img = normalizePercentile(test_img)\n",
- "test_img = test_img[None, ..., None]\n",
- "test_prediction = unet.predict(test_img)\n",
- "test_prediction = np.squeeze(test_prediction)\n",
- "assert test_prediction.ndim == 3\n",
- "\n",
- "shape_pred = test_prediction.shape\n",
- "channel_pred_idx = shape_pred.index(min(shape_pred))\n",
- "n_channels = shape_pred[channel_pred_idx]\n",
- "\n",
- "test_prediction = test_prediction[None, ...]\n",
- "test_output_path = os.path.join(output_root, \"test_output.npy\")\n",
- "np.save(test_output_path, test_prediction)\n",
- "\n",
- "# create the channel names for the output\n",
- "channel_names = []\n",
- "\n",
- "for idx in range(n_channels):\n",
- " channel_names.append(f'channel{idx}')\n",
- "\n",
- "# create the input tensor\n",
- "input_tensor = bioimageio_spec.InputTensorDescr(id=bioimageio_spec.TensorId('input0'),\n",
- " description= 'This is the test input tensor created from the example image.',\n",
- " axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),\n",
- " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='y', description='', type='space', unit=None, scale=1.0, concatenable=False),\n",
- " bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='x', description='', type='space', unit=None, scale=1.0, concatenable=False),\n",
- " bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=['channel0'])],\n",
- " test_tensor = bioimageio_spec.FileDescr(source = test_input_path),\n",
- " preprocessing = [bioimageio_spec.EnsureDtypeDescr(kwargs=bioimageio_spec.EnsureDtypeKwargs(dtype=\"float32\")),\n",
- " bioimageio_spec.ScaleRangeDescr(kwargs=bioimageio_spec.ScaleRangeKwargs(axes = ['x','y'], min_percentile = 1.0 , max_percentile = 99.8) ),\n",
- " ],\n",
- " )\n",
- "\n",
- "\n",
- "\n",
- "# create the output tensor\n",
- "output_tensor = bioimageio_spec.OutputTensorDescr( axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),\n",
- " bioimageio_spec.SpaceOutputAxis(size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='y', offset=0), id='y', description='', type='space', unit=None, scale=1.0),\n",
- " bioimageio_spec.SpaceOutputAxis( size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='x', offset=0), id='x', description='', type='space', unit=None, scale=1.0),\n",
- " bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=channel_names)],\n",
- " test_tensor = bioimageio_spec.FileDescr(source = test_output_path) )\n",
- "\n",
- "attachments = []\n",
- "# attach the QC report to the model (if it exists)\n",
- "qc_path = os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')\n",
- "if os.path.exists(qc_path):\n",
- " attachments.append(FileDescr(source = qc_path))\n",
- "\n",
- "# Include a post-processing deepImageJ macro\n",
- "macro = \"Contours2InstanceSegmentation.ijm\"\n",
- "url = f\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/{macro}\"\n",
- "path = os.path.join(output_root, macro)\n",
- "with requests.get(url, stream=True) as r:\n",
- " text = r.text\n",
- " if text.startswith(\"4\"):\n",
- " raise RuntimeError(f\"An error occured when downloading {url}: {r.text}\")\n",
- " with open(path, \"w\") as f:\n",
- " f.write(r.text)\n",
- "attachments.append(FileDescr(source = path))\n",
- "\n",
- "# make cover image\n",
- "cover = np.squeeze(test_img)\n",
- "pred_cover = np.squeeze(test_prediction)\n",
- "\n",
- "for idx in range(1, n_channels):\n",
- " if channel_pred_idx == 0:\n",
- " cover = np.concatenate((cover, pred_cover[idx,:,:]), axis=1)\n",
- " elif channel_pred_idx == 1:\n",
- " cover = np.concatenate((cover, pred_cover[:,idx,:]), axis=1)\n",
- " elif channel_pred_idx == 2:\n",
- " cover = np.concatenate((cover, pred_cover[:,:,idx]), axis=1)\n",
- "\n",
- "cover_path = os.path.join(output_root, \"cover.png\")\n",
- "plt.imsave(cover_path, cover, cmap='gray')\n",
- "\n",
- "# make weights description\n",
- "unet_tf_weights = bioimageio_spec.KerasHdf5WeightsDescr(source=weight_path, tensorflow_version=tf.__version__)\n",
- "unet_weights = bioimageio_spec.WeightsDescr(keras_hdf5=unet_tf_weights)\n",
- "\n",
- "# create model description for export\n",
- "model_description = bioimageio_spec.ModelDescr(name=Trained_model_name,\n",
- " description=Trained_model_description,\n",
- " covers=[cover_path],\n",
- " authors=authors,\n",
- " attachments=attachments,\n",
- " cite=citation_spec,\n",
- " license=license,\n",
- "\n",
- " maintainers=[maintainer_spec],\n",
- " tags=['zerocostdl4mic', 'deepimagej', 'segmentation', 'unet'],\n",
- " documentation= documentation_path,\n",
- " inputs=[input_tensor],\n",
- " outputs=[output_tensor],\n",
- " packaged_by=[packager_spec],\n",
- " weights=unet_weights,\n",
- " training_data=training_data,\n",
- "\n",
- " )\n",
- "\n",
- "\n",
- "# test model\n",
- "summary = bioimageio_core.test_model(model_description, weight_format=\"keras_hdf5\")\n",
- "summary.display()\n",
- "\n",
- "success = summary.status == \"passed\"\n",
- "\n",
- "save_bioimageio_package(model_description, output_path=Path(output_path))\n",
- "\n",
- "if success:\n",
- " print(\"The bioimage.io model was successfully exported to\", output_path)\n",
- "else:\n",
- " print(\"The bioimage.io model was exported to\", output_path)\n",
- " print(\"Some tests of the model did not work! You can still download and test the model.\")\n",
- " print(\"You can still download and test the model, but it may not work as expected.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-tJeeJjLnRkP"
- },
- "source": [
- "# **6. Using the trained model**\n",
- "\n",
- "---\n",
- "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "d8wuQGjoq6eN"
- },
- "source": [
- "## **6.1 Generate prediction(s) from unseen dataset**\n",
- "---\n",
- "\n",
- "The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n",
- "\n",
- "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n",
- "\n",
- "**`Result_folder`:** This folder will contain the predicted output images.\n",
- "\n",
- " Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n",
- "\n",
- " **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "colab": {
- "background_save": true
- },
- "id": "y2TD5p7MZrEb"
- },
- "outputs": [],
- "source": [
- "\n",
- "\n",
- "# ------------- Initial user input ------------\n",
- "#@markdown ###Provide the path to your dataset and to the folder where the predicted masks will be saved (Result folder), then play the cell to predict the output on your unseen images and store it.\n",
- "Data_folder = '' #@param {type:\"string\"}\n",
- "Results_folder = '' #@param {type:\"string\"}\n",
- "\n",
- "#@markdown ###Do you want to use the current trained model?\n",
- "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n",
- "\n",
- "#@markdown ###If not, please provide the path to the model folder:\n",
- "\n",
- "Prediction_model_folder = \"\" #@param {type:\"string\"}\n",
- "\n",
- "#Here we find the loaded model name and parent path\n",
- "Prediction_model_name = os.path.basename(Prediction_model_folder)\n",
- "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n",
- "\n",
- "\n",
- "# ------------- Failsafes ------------\n",
- "if (Use_the_current_trained_model):\n",
- " print(\"Using current trained network\")\n",
- " Prediction_model_name = model_name\n",
- " Prediction_model_path = model_path\n",
- "\n",
- "full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n",
- "if os.path.exists(full_Prediction_model_path):\n",
- " print(\"The \"+Prediction_model_name+\" network will be used.\")\n",
- "else:\n",
- " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n",
- " print('Please make sure you provide a valid model path and model name before proceeding further.')\n",
- "\n",
- "\n",
- "# ------------- Prepare the model and run predictions ------------\n",
- "\n",
- "# Load the model and prepare generator\n",
- "\n",
- "\n",
- "\n",
- "unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n",
- "Input_size = unet.layers[0].output_shape[0][1:3]\n",
- "print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n",
- "\n",
- "# Create a list of sources\n",
- "source_dir_list = os.listdir(Data_folder)\n",
- "number_of_dataset = len(source_dir_list)\n",
- "print('Number of dataset found in the folder: '+str(number_of_dataset))\n",
- "\n",
- "predictions = []\n",
- "for i in tqdm(range(number_of_dataset)):\n",
- " predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n",
- " # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n",
- "\n",
- "\n",
- "# Save the results in the folder along with the masks according to the set threshold\n",
- "saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix)\n",
- "\n",
- "\n",
- "# ------------- For display ------------\n",
- "print('--------------------------------------------------------------')\n",
- "\n",
- "\n",
- "def show_prediction_mask(file=os.listdir(Data_folder)):\n",
- "\n",
- " plt.figure(figsize=(10,6))\n",
- " # Wide-field\n",
- " plt.subplot(1,2,1)\n",
- " plt.axis('off')\n",
- " img_Source = plt.imread(os.path.join(Data_folder, file))\n",
- " plt.imshow(img_Source, cmap='gray')\n",
- " plt.title('Source image',fontsize=15)\n",
- " # Prediction\n",
- " plt.subplot(1,2,2)\n",
- " plt.axis('off')\n",
- " img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n",
- " plt.imshow(img_Prediction, cmap='gray')\n",
- " plt.title('Prediction',fontsize=15)\n",
- "\n",
- "interact(show_prediction_mask);\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hvkd66PldsXB"
- },
- "source": [
- "## **6.2. Download your predictions**\n",
- "---\n",
- "\n",
- "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BphZ0wBrC2Zw"
- },
- "source": [
- "# **7. Version log**\n",
- "\n",
- "---\n",
- "**v2.1.3**: \n",
- "\n",
- "* Updated Bioimage.IO model export to latest version (core-0.6.9, spec-0.5.3.2)\n",
- "* Fixed model importation from Bioimage.IO\n",
- "* Fixed Tensorflow version to 2.15\n",
- "* Bug fixes\n",
- "\n",
- "**v2.1.2**: \n",
- "\n",
- "* Correct for data loading to avoid .DS_Store or similar\n",
- "\n",
- "**v2.1.1**: \n",
- "\n",
- "* Replaced all absolute pathing with relative pathing\n",
- "\n",
- "**v2.1**:\n",
- "* Updated to TensorFlow 2.11\n",
- "* Updated to `fpdf2` and add lines to ensure a proper format. Correct keras package version parsing.\n",
- "---"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "UvSlTaH14s3t"
- },
- "source": [
- "# **Thank you for using 2D U-Net multilabel segmentation!**"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "gpuType": "T4",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.19"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}