diff --git a/Colab_notebooks_Beta/DeepSTORM_2D_ZeroCostDL4Mic.ipynb.json b/Colab_notebooks_Beta/DeepSTORM_2D_ZeroCostDL4Mic.ipynb.json deleted file mode 100644 index 80706381..00000000 --- a/Colab_notebooks_Beta/DeepSTORM_2D_ZeroCostDL4Mic.ipynb.json +++ /dev/null @@ -1 +0,0 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"DeepSTORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4","colab_type":"text"},"source":["# **Deep-STORM 2D**\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM) data, first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The network reconstructs 2D super-resolution images and is trained in a supervised manner, using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from a random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network outputs a super-resolution image with increased pixel density (typically an upsampling factor of 8 in each dimension). \n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y","colab_type":"text"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run-through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. 
You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click it to execute the cell. Once execution is done, the play button animation stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX","colab_type":"text"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated datasets of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate a training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz","colab_type":"text"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn","colab_type":"text"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime settings are correct then Google did not allocate a GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-","colab_type":"text"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press Enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of the notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in to your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on the \"Files\" tab on the left. Refresh it. Your Google Drive folder should now be available here as \"drive\". \n","\n","# Mount the user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ","colab_type":"text"},"source":["# **2. 
Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Install Deep-STORM and dependencies\n","\n","# %% Model definition + helper functions\n","\n","# Import keras modules and libraries\n","from keras.models import Model\n","from keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization\n","from keras.callbacks import Callback\n","from keras import backend as K\n","from keras import optimizers\n","from keras import losses\n","from keras.preprocessing.image import ImageDataGenerator\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image\n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","from sys import getsizeof\n","import math\n","from astropy.visualization import simple_norm\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n"," \n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. 
for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h = h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architecture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the model building function for an arbitrary input size\n","def buildModel(input_dim):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr=0.001)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training image patches and their corresponding heatmap labels.\n"," \n"," # Inputs\n"," patches - the upsampled training image patches\n"," heatmaps - the corresponding heatmap training labels\n"," modelPath - the folder where the trained model is saved\n"," epochs - the number of training epochs\n"," steps_per_epoch - the number of training steps per epoch\n"," batch_size - the batch size used for training\n","\n"," # Outputs\n"," function saves the weights of the trained model to an hdf5 file, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. 
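\n"," (Concretely: the weights go to weights_model.hdf5 and the normalization factors to meanstd_model.mat, both written inside modelPath.)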
\n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size=0.3, random_state=42)\n"," print('Number of Training Examples: %d' % X_train.shape[0])\n"," print('Number of Validation Examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n","\n"," # Set the dimension ordering according to the TensorFlow convention\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_model.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Reduce the learning rate when the validation loss reaches a plateau\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and compilation\n"," model = buildModel((psize, psize, 1))\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # 
apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user that training has begun\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \\\n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \\\n"," validation_data=(X_test_norm, Y_test), \\\n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user that training has ended\n"," print('Training Completed!')\n"," \n"," # # plot the loss function progression during training\n"," # loss = train_history.history['loss']\n"," # val_loss = train_history.history['val_loss']\n"," # plt.figure()\n"," # plt.plot(loss)\n"," # plt.plot(val_loss)\n"," # plt.legend(['train_loss', 'val_loss'])\n"," # plt.xlabel('Iteration #')\n"," # plt.ylabel('Loss Function')\n"," # plt.title(\"Loss function progress during training\")\n"," # plt.show()\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the file if needed). \n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i]])\n","\n"," \n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test}\n"," sio.savemat(os.path.join(modelPath,\"meanstd_model.mat\"), mdict)\n"," return\n","\n","\n","def test_model(dataPath, filename, modelPath, savePath, \\\n"," upsampling_factor=8, debug=0, display=True):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the folder containing the tiff stack of test images \n"," filename - the name of the tiff stack of test images\n"," modelPath - the folder containing the weights and normalization factors generated in train_model\n"," savePath - the folder where the recovered SR image is saved\n"," upsampling_factor - the upsampling factor for reconstruction (default 8)\n"," debug - boolean whether to save individual frame predictions (default 0)\n"," display - boolean whether to display the summed wide-field and prediction images (default True)\n"," \n"," # Outputs\n"," function saves the recovered SR image and the wide-field image as tif files, and optionally saves \n"," individual frame predictions in case debug=1. 
(default is debug=0) \n"," \"\"\"\n"," \n"," # load the tiff data\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," \n"," # get dataset dimensions\n"," (K, M, N) = Images.shape\n"," \n"," # upsampling using a simple nearest neighbor interp.\n"," Images_upsampled = np.zeros((K, M*upsampling_factor, N*upsampling_factor))\n"," for i in range(Images.shape[0]):\n"," Images_upsampled[i,:,:] = np.kron(Images[i,:,:], np.ones((upsampling_factor,upsampling_factor))) \n"," Images = Images_upsampled\n"," \n"," # upsampled frames dimensions\n"," (K, M, N) = Images.shape\n"," \n"," # Build the model for a bigger image\n"," model = buildModel((M, N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_model.hdf5'))\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'meanstd_model.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n","\n"," # Setting type\n"," Images = Images.astype('float32')\n"," \n"," # Rescale each frame to [0,1], then standardize using the saved mean and std\n"," Images_norm = np.zeros(Images.shape,dtype=np.float32)\n"," for i in range(Images.shape[0]):\n"," Images_norm[i,:,:] = project_01(Images[i,:,:])\n"," Images_norm[i,:,:] = normalize_im(Images_norm[i,:,:], test_mean, test_std)\n","\n"," # Reshaping\n"," Images_norm = np.expand_dims(Images_norm,axis=3) \n"," \n"," # Make a prediction and time it\n"," # start = time.time()\n"," predicted_density = model.predict(Images_norm, batch_size=1)\n"," # end = time.time()\n"," \n"," # threshold negative values\n"," predicted_density[predicted_density < 0] = 0\n"," \n"," # resulting sum images\n"," WideField = np.squeeze(np.sum(Images_norm, axis=0))\n"," Recovery = np.squeeze(np.sum(predicted_density, axis=0))\n"," \n"," if (display):\n"," # Look at the sum image\n"," f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True)\n"," ax1.imshow(WideField)\n"," ax1.set_title('Wide Field')\n"," ax1.axis('off')\n"," ax2.imshow(Recovery)\n"," ax2.set_title('Sum of Predictions')\n"," ax2.axis('off')\n"," f.subplots_adjust(hspace=0)\n"," plt.show()\n"," \n"," # Save predictions to a matfile to open later in matlab\n"," # mdict = {\"Recovery\": Recovery}\n"," # sio.savemat(savename+'Prediction.mat', mdict)\n"," # io.imsave(savename+'PredictionStack.tif', predicted_density)\n"," io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'),Recovery)\n","\n"," # io.imsave(os.path.join(savePath, filename'Prediction.tif'), Recovery)\n"," io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'),WideField)\n"," # io.imsave(savePath+'Widefield.tif', WideField)\n","\n"," # save predicted density in each frame for debugging purposes\n"," if debug:\n"," mdict = {\"Predictions\": predicted_density}\n"," sio.savemat(os.path.join(savePath, os.path.splitext(filename)[0] + '_debug_save.mat'), mdict)\n"," \n"," return\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_MultiThreaded(xc_array, yc_array, photon_array, sigma_array, image_size = 64, pixel_size = 100):\n"," Image = np.zeros((image_size, image_size))\n"," for ij in prange(image_size*image_size):\n"," j = int(ij/image_size)\n"," i = ij - j*image_size\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (photon > 0) and (sigma > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't 
bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," Image[j][i] += 0.25*photon*ErfX*ErfY\n"," return Image\n","\n","\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos","colab_type":"text"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or another simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l","colab_type":"text"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file."]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Load raw data\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","(number_of_frames, M, N) = Images.shape\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","with Image.open(ImageData_path) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","pixel_size = meta_dict['XResolution'][0][0]/100000\n","print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n","\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Load the localization file and display the last rows\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n","# LocData.head()\n","# print(LocData['sigma [nm]'].sum()/6459)\n","# print(LocData['sigma [nm]'].min())\n","# print(LocData['sigma [nm]'].max())\n","\n","# print(LocData['intensity [photon]'].sum()/6459)\n","# print(LocData['intensity [photon]'].min())\n","# print(LocData['intensity [photon]'].max())\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9","colab_type":"text"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distributed emitters in a field-of-view. 
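\n","\n","For reference, the camera model applied in the simulation cell below is (a pseudo-NumPy sketch, with noise_free standing in for the noise-free image stack built from the emitter list):\n","\n","```\n","noisy = (ADC_per_photon_conversion * np.random.poisson(noise_free)\n","         + ReadOutNoise_ADC * np.random.normal(size=noise_free.shape)\n","         + ADC_offset)  # then clipped to 65535 and cast to uint16\n","```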
\n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with `Sigma = 0.21 x Lambda / NA`\n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame\n","* The `n_photons` and `wavelength` can additionally include some Gaussian variability by setting `wavelength_std` and `n_photons_std`\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400 #@param {type:\"number\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","ADC_per_photon_conversion = 1.0 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 2 #@param {type:\"number\"}\n","ADC_offset = 100#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","number_of_frames = 20#@param {type:\"integer\"}\n","NA = 1.4 #@param {type:\"number\"}\n","wavelength = 700#@param {type:\"number\"}\n","wavelength_std = 50#@param {type:\"number\"}\n","n_photons = 5000 #@param {type:\"number\"}\n","n_photons_std = 500#@param {type:\"number\"}\n","\n","# # @markdown Parameters about the ground truth image display:\n","# upsampling_factor = 8 #@param {type:\"number\"}\n","# sigma_heatmap = 10 #@param {type:\"number\"}\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = round(emitter_density*FOV_size*FOV_size/10**6)\n","print('Number of molecules / FOV: '+str(n_molecules))\n","\n","sigma = 0.21*wavelength/NA\n","sigma_std = 0.21*wavelength_std/NA\n","print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","print('Final image size: '+str(M)+'x'+str(M))\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","# heatmaps = np.zeros((number_of_frames, upsampling_factor*M, upsampling_factor*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_molecules)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_molecules)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_molecules)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_molecules)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," 
all_sigmas += sigma_array.tolist()\n","\n"," # Get the approximated locations according to the grid pixel size\n"," Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # Build Localization image\n"," for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," locImage[f][r][c] += 1\n","\n"," # NoiseFreeImages[f] = FromLoc2Image_MultiThreaded(x_c, y_c, M, pixel_size, sigma)\n"," # heatmaps[f] = FromLoc2Image_MultiThreaded(x_c, y_c, upsampling_factor*M, pixel_size/upsampling_factor, sigma_heatmap)\n"," NoiseFreeImages[f] = FromLoc2Image_MultiThreaded(x_c, y_c, photon_array, sigma_array, M, pixel_size)\n"," # heatmaps[f] = FromLoc2Image_MultiThreaded2(x_c, y_c, photon_array, sigma_heatmap*np.ones(n_molecules), upsampling_factor*image_size, pixel_size/upsampling_factor)\n","\n","\n","# ---------------------------- Create DataFrame for the localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","\n","# Convert to 16-bit\n","Images[Images > 65535] = 65535\n","Images = Images.astype(np.uint16)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,4,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," # plt.subplot(1,4,2)\n"," # plt.imshow(heatmaps[frame-1], interpolation='nearest')\n"," # plt.title('Heatmap')\n"," # plt.axis('off');\n","\n"," plt.subplot(1,4,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,4,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the tail of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","colab_type":"code","cellView":"form","colab":{}},"source":["# @markdown ---\n","# @markdown #Play this cell to save the simulated stack\n","# @markdown ####Please select a path to the folder where the simulated data will be saved.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","io.imsave(os.path.join(Save_path, 
'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY","colab_type":"text"},"source":["## **3.2. Generate training patches**\n","---\n"]},{"cell_type":"code","metadata":{"id":"_8WWVYr7k47F","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size,patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size,patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size,patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of patches skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(Images):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n"," HeatMapImage = gaussian_filter(LocImage, gaussian_sigma)\n","\n"," # Limit the maximal number of training examples to max_num_patches\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Assign the 
patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = 100*HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+W)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","print('Size of patches: '+str(dataSize)+' MB')\n","print(str(k-1)+' patches were generated.')\n","\n","# Displaying the time elapsed for patch generation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest')\n"," plt.title('Raw data (patch #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","# print(patches.shape)\n","# print(patches.dtype)\n","# print(patches.min())\n","# print(patches.max())\n","\n","\n","# print(heatmaps.shape)\n","# print(heatmaps.dtype)\n","# print(heatmaps.min())\n","# print(heatmaps.max())\n","\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz","colab_type":"text"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx","colab_type":"text"},"source":["## **4.1. 
Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters."]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","\n","epochs = 100#@param {type:\"integer\"}\n","steps_per_epoch = 400#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","# Use_TrainingSetMATLAB_file = False #@param {type:\"boolean\"}\n","# Training_source = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/demo 1 - Simulated Microtubules/TrainingSet.mat\" #@param {type: \"string\"}\n","\n","# if Use_TrainingSetMATLAB_file:\n","# # Load training data and divide it to training and validation sets\n","# matfile = h5py.File(Training_source, 'r')\n","# patches = np.array(matfile['patches'])\n","# heatmaps = 100.0 * np.array(matfile['heatmaps'])\n","\n","if not ('patches' in locals()):\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: No patches were found in memory currently. !!'+W)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA","colab_type":"text"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Start Training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# here we check that no model with the same name already exists; if so, delete it\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","Save_path = model_path +'/'+model_name\n","os.makedirs(Save_path)\n","\n","train_model(patches, heatmaps, Save_path, steps_per_epoch=steps_per_epoch, epochs=epochs, batch_size=batch_size)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CHVTRjEOLRDH","colab_type":"text"},"source":["##**4.3. Download your model(s) from Google Drive**\n","\n","\n","---\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 4.1. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr","colab_type":"text"},"source":["# **5. 
Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","colab_type":"code","cellView":"form","colab":{}},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter 'model_name' (see section 4.1). Provide the full path to this folder in 'QC_model_path'. \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4","colab_type":"text"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value for the model's prediction on a validation image compared to its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'))\n","plt.show()\n","\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY","colab_type":"text"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using the corresponding localization data contained in \"QC_loc_folder\"!\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
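\n","\n","With the normalization used in the code below (data_range = 1.0), this reduces to PSNR = 10 * log10(1 / MSE).\n","\n","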
The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------------------ Definitions specific to QC ------------------------\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True if gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, \"QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. 
GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(os.listdir(QC_image_folder), os.listdir(QC_loc_folder)):\n"," if not os.path.isdir(os.path.join(QC_image_folder,imageFilename)):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," # Get the GT image from the localizations on the right pixel grid\n"," with Image.open(os.path.join(QC_image_folder,imageFilename)) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," N = meta_dict['ImageWidth'][0]\n"," M = meta_dict['ImageLength'][0]\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n","\n"," pixelSize = meta_dict['XResolution'][0][0]/100000\n"," print('Pixel size obtained from metadata: '+str(pixelSize)+' nm')\n"," pixel_size_hr = pixelSize/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," locImage = np.zeros((Mhr,Nhr), dtype = np.float64)\n"," x = list(LocData['x [nm]'])\n"," y = list(LocData['y [nm]'])\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(x[i]/pixel_size_hr),Nhr-1),0)) for i in range(len(x))]\n"," Rhr_emitters = [int(max(min(round(y[i]/pixel_size_hr),Mhr-1),0)) for i in range(len(y))]\n","\n"," # Build Localization image\n"," for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," locImage[r][c] += 1\n","\n"," # Save the loc image files\n"," io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," 
      img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","      img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n","      # Save RSE maps\n","      img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n","      io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n","      img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n","      io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n","\n","      # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n","      # Normalised Root Mean Squared Error (root of the mean of the squared errors)\n","      NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(test_GT_norm - test_prediction_norm)))\n","      NRMSE_GTvsWF = np.sqrt(np.mean(np.square(test_GT_norm - test_wf_norm)))\n","\n","      # We can also measure the peak signal to noise ratio between the images\n","      PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","      PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n","      writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n","      # Collect values to display in dataframe output\n","      file_name_list.append(imageFilename)\n","      mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","      mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n","      NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","      NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n","      PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","      PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. 
GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(QC_image_folder)):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT)\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. 
GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n","\n","print('--------------------------------------------')\n","pdResults.head()\n"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd","colab_type":"text"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Wkyd9gFrkmbN","colab_type":"text"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
**`Data_folder`:** This folder should contain the images that you want to process with your trained network.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"MUBOYdAfUtnA","colab_type":"code","cellView":"form","colab":{}},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown ###Otherwise, please provide the path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n","  prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n","  print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n","  W = '\\033[0m'  # white (normal)\n","  R = '\\033[31m' # red\n","  print(R+'!! WARNING: The chosen model does not exist !!'+W)\n","  print('Please make sure you provide a valid model path before proceeding further.')\n","\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the\n","# saved training weights, and normalization created using train_model.\n","\n","for filename in os.listdir(Data_folder):\n","  # run the testing/reconstruction process\n","  print(\"------------------------------------\")\n","  print(\"Running prediction on: \"+ filename)\n","  test_model(Data_folder, filename, prediction_model_path, Result_folder, display = False);\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for prediction\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60)\n","hours, minutes = divmod(minutes, 60)\n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display -------------------------------\n","print('--------------------------------------------------------------------')\n","\n","@interact\n","def show_prediction_results(file=os.listdir(Data_folder)):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","\n"],"execution_count":0,"outputs":[]},
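{"cell_type":"code","metadata":{"id":"zipResultsCell","colab_type":"code","cellView":"form","colab":{}},"source":["#@markdown ##(Optional) Zip the Result_folder for easier download\n","\n","# Editor's addition, not part of the original Deep-STORM notebook: a minimal\n","# convenience sketch that archives Result_folder (set in section 6.1 above)\n","# so the predictions can be downloaded from Google Drive in a single file.\n","import shutil\n","\n","zip_basename = os.path.join(os.path.dirname(Result_folder.rstrip('/')), 'Deep-STORM_predictions')\n","zip_path = shutil.make_archive(zip_basename, 'zip', Result_folder)\n","print('Archive created: '+zip_path)\n"],"execution_count":0,"outputs":[]},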
{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n","colab_type":"text"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL of its results elsewhere by downloading it from Google Drive, and then clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** any files that have the same name."]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV","colab_type":"text"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file