From e00806c091fe9c84faa6bffa5760b1bd9d63a04f Mon Sep 17 00:00:00 2001 From: Cyril Achard <94955160+C-Achard@users.noreply.github.com> Date: Sun, 11 Jun 2023 12:45:58 +0200 Subject: [PATCH] Voronoi-Otsu labeling + instance segmentation code overhaul (#34) * Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling * Disabled small removal in Voronoi-Otsu * Added new docs for instance seg * Docs + UI update - Updated welcome/README - Changed step for DoubleCounter * Update requirements.txt Fix typo * isort * Fix tests * Fixed widget parenting issues and instance seg widget init - Fixed widget parents that were incorrectly initialized - Improved use of instance seg. method classes and init * Fix inference * Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> * Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com> * Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black * black * Complete instance method evaluation * Added pre-commit hooks * Enforce pre-commit style * Update .gitignore * Version bump * Updated project files * Fixed missing parent error * Fixed wrong value in instance sliders * Removing dask-image * Fixed erroneous dtype conversion * Update test_plugin_utils.py * Temporary test action patch * Update plugin_convert.py * Update tox.ini Added pocl for testing on GH Actions * Update tox.ini * Found existing pocl * Updated utils test to avoid Voronoi-Otsu (VO is missing a CL runtime on CI) * Relabeling tests * Run full suite of pre-commit hooks * Enforce style * Added new pre-commit hooks * Documentation update, crop contrast fix * Update plugin_model_inference.py * Updated hooks --------- Co-authored-by: gityves <114951621+gityves@users.noreply.github.com> --- .github/workflows/test_and_deploy.yml | 2 + .gitignore | 5 + .pre-commit-config.yaml | 13 +- README.md | 5 +- docs/res/code/model_instance_seg.rst | 23 + docs/res/code/plugin_convert.rst | 5 - docs/res/guides/cropping_module_guide.rst | 6 +- docs/res/guides/utils_module_guide.rst | 10 +- docs/res/welcome.rst | 36 +- napari_cellseg3d/_tests/pytest.ini | 2 + napari_cellseg3d/_tests/res/test_labels.tif | Bin 0 -> 2026 bytes .../_tests/test_labels_correction.py | 52 ++ napari_cellseg3d/_tests/test_plugin_utils.py | 18 +- .../_tests/test_weight_download.py | 6 +- .../code_models/model_framework.py | 15 +- .../code_models/model_instance_seg.py | 349 ++++++++- napari_cellseg3d/code_models/model_workers.py | 134 ++-- .../code_models/models/unet/buildingblocks.py | 5 +- .../code_models/models/unet/model.py | 8 +- napari_cellseg3d/code_plugins/plugin_base.py | 3 +- .../code_plugins/plugin_convert.py | 312 ++------ napari_cellseg3d/code_plugins/plugin_crop.py | 15 +- .../code_plugins/plugin_helper.py | 6 +- .../code_plugins/plugin_metrics.py | 3 +- .../code_plugins/plugin_model_inference.py | 49 +- .../code_plugins/plugin_model_training.py | 23 +- .../code_plugins/plugin_review.py | 19 +- .../code_plugins/plugin_review_dock.py | 11 +- .../code_plugins/plugin_utilities.py | 22 +- napari_cellseg3d/config.py | 34 +- .../dev_scripts/artefact_labeling.py | 433 +++++++++++ napari_cellseg3d/dev_scripts/convert.py | 4 +- .../dev_scripts/correct_labels.py | 369 ++++++++++ napari_cellseg3d/dev_scripts/drafts.py | 3 +- .../dev_scripts/evaluate_labels.py | 693 ++++++++++++++++++ napari_cellseg3d/dev_scripts/thread_test.py | 16 +- napari_cellseg3d/interface.py | 82 ++- napari_cellseg3d/utils.py | 126 ++-- notebooks/assess_instance.ipynb | 503 +++++++++++++ notebooks/csv_cell_plot.ipynb | 2 - pyproject.toml | 44 +- requirements.txt | 7 +- tox.ini | 2 +- 43 files changed, 2857 insertions(+), 618 deletions(-) create mode 100644 napari_cellseg3d/_tests/pytest.ini create mode 100644 napari_cellseg3d/_tests/res/test_labels.tif create mode 100644 napari_cellseg3d/_tests/test_labels_correction.py create mode 100644 napari_cellseg3d/dev_scripts/artefact_labeling.py create mode 100644 napari_cellseg3d/dev_scripts/correct_labels.py create mode 100644 napari_cellseg3d/dev_scripts/evaluate_labels.py create mode 100644 notebooks/assess_instance.ipynb diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5dcd11ae..ea0a1e46 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -8,12 +8,14 @@ on:
branches: - main - npe2 + - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - npe2 + - cy/voronoi-otsu workflow_dispatch: jobs: diff --git a/.gitignore b/.gitignore index d08ff9f2..df43b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,8 @@ notebooks/full_plot.html *.csv *.png *.prof + +#include test data +!napari_cellseg3d/_tests/res/test.tif +!napari_cellseg3d/_tests/res/test.png +!napari_cellseg3d/_tests/res/test_labels.tif diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7053663e..f9fe2853 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,11 +5,14 @@ repos: # - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black", --line-length=79] + - id: check-yaml + - id: check-added-large-files + - id: check-toml +# - repo: https://github.com/pycqa/isort +# rev: 5.12.0 +# hooks: +# - id: isort +# args: ["--profile", "black", --line-length=79] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. rev: 'v0.0.262' diff --git a/README.md b/README.md index 25d02fa5..ece6c6f4 100644 --- a/README.md +++ b/README.md @@ -151,8 +151,9 @@ Distributed under the terms of the [MIT] license. ## Acknowledgements -This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). - +This plugin was developed by Cyril Achard, Maxime Vidal, Mackenzie Mathis. +This work was funded, in part, from the Wyss Center to the [Mathis Laboratory of Adaptive Motor Control](https://www.mackenziemathislab.org/). +Please refer to the documentation for full acknowledgements. ## Plugin base This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst index e4146ec1..3b323173 100644 --- a/docs/res/code/model_instance_seg.rst +++ b/docs/res/code/model_instance_seg.rst @@ -1,6 +1,29 @@ model_instance_seg.py =========================================== +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu + :members: __init__ + Functions ------------- diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index c7dc7df9..03944510 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -19,11 +19,6 @@ ToSemanticUtils .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToSemanticUtils :members: __init__ -InstanceWidgets -********************************** -.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::InstanceWidgets - :members: __init__, run_method - ToInstanceUtils ********************************** .. 
autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToInstanceUtils diff --git a/docs/res/guides/cropping_module_guide.rst b/docs/res/guides/cropping_module_guide.rst index a862ffff..89cbb39a 100644 --- a/docs/res/guides/cropping_module_guide.rst +++ b/docs/res/guides/cropping_module_guide.rst @@ -33,9 +33,9 @@ If you'd like to change the size of the volume, change the parameters as previou Creating new layers --------------------------------- -To "zoom in" your volume, you can use the "Create new layers" checkbox to make a new layer not controlled by the plugin next -time you hit Start. This way, you can first select your region of interest by using the tool as described above, -the enable the option, select the cropped layer, and define a smaller crop size to have easier access to your region of interest. +To "zoom in" your volume, you can use the "Create new layers" checkbox to make a new cropping layer controlled by the sliders +next time you hit Start. This way, you can first select your region of interest by using the tool as described above, +then enable the option, select the cropped region produced before as the input layer, and define a smaller crop size to crop within your region of interest. Interface & functionalities --------------------------------------------------------------- diff --git a/docs/res/guides/utils_module_guide.rst b/docs/res/guides/utils_module_guide.rst index 407ae710..64e8a3ce 100644 --- a/docs/res/guides/utils_module_guide.rst +++ b/docs/res/guides/utils_module_guide.rst @@ -4,13 +4,21 @@ Label conversion utility guide ================================== This utility will let you convert labels to various formats. + You will have to specify the results directory for saving; afterwards you can run each action on a folder or on the currently selected layer. You can : +* Crop 3D volumes : + Please refer to :ref:`cropping_module_guide` for a guide on using the cropping utility. + * Convert to instance labels : - This will convert 0/1 semantic labels to instance label, with a unique ID for each object using the watershed method. + This will convert 0/1 semantic labels to instance labels, with a unique ID for each object. + The available methods for this are : + * Connected components : simple method that will assign a unique ID to each connected component. Does not work well for touching objects (objects will often be fused), but works for anisotropic volumes. + * Watershed : method based on topographic maps. Works well for touching objects and anisotropic volumes, though closely touching objects may still be fused. + * Voronoi-Otsu : method based on Voronoi diagrams. Works well for touching objects but only for isotropic volumes. * Convert to semantic labels : This will convert instance labels with unique IDs per object into 0/1 semantic labels, for example for training. diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 6832e71e..892549a8 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -38,22 +38,28 @@ You can install `napari-cellseg3d` via [pip]: ``pip install napari-cellseg3d`` - For local installation, please run: +For local installation after cloning, please run in the CellSeg3D folder: ``pip install -e .`` Requirements -------------------------------------------- +.. note:: + A **CUDA-capable GPU** is not needed but **very strongly recommended**, especially for training and possibly inference. + .. important:: - A **CUDA-capable GPU** is not needed but **very strongly recommended**, especially for training.
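For readers of the utils guide above, the connected-components conversion can be reproduced outside the plugin. A minimal sketch, assuming scikit-image is installed; the threshold and minimum size are illustrative, not the plugin's defaults::

    import numpy as np
    from skimage.measure import label
    from skimage.morphology import remove_small_objects

    semantic = np.random.rand(32, 64, 64) > 0.95  # stand-in 0/1 semantic volume
    instance = label(semantic)  # one unique ID per connected component
    instance = remove_small_objects(instance, min_size=10)  # drop tiny objects
    print(f"{instance.max()} objects kept")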
+ This package requires you have napari installed with PyQt5 or PySide2 first. + If you do not have a Qt backend you can use : -This package requires you have napari installed first. + ``pip install napari-cellseg3d[all]`` + to install PyQt5 by default. -It also depends on PyTorch and some optional dependencies of MONAI. These come in the pip package above, but if +It also depends on PyTorch and some optional dependencies of MONAI. These come in the pip package as requirements, but if you need further assistance see below. * For help with PyTorch, please see `PyTorch's website`_ for installation instructions, with or without CUDA depending on your hardware. + Depending on your setup, you might wish to install torch first. * If you get errors from MONAI regarding missing readers, please see `MONAI's optional dependencies`_ page for instructions on getting the readers required by your images. @@ -70,14 +76,13 @@ To use the plugin, please run: Then go into Plugins > napari-cellseg3d, and choose which tool to use: - - **Review**: This module allows you to review your labels, from predictions or manual labeling, and correct them if needed. It then saves the status of each file in a csv, for easier monitoring - **Inference**: This module allows you to use pre-trained segmentation algorithms on volumes to automatically label cells - **Training**: This module allows you to train segmentation algorithms from labeled volumes - **Utilities**: This module allows you to use several utilities, e.g. to crop your volumes and labels, compute prediction scores or convert labels - **Help/About...** : Quick access to version info, Github page and docs -See above for links to detailed guides regarding the usage of the modules. +See the documentation for links to detailed guides regarding the usage of the modules. Acknowledgments & References --------------------------------------------- @@ -90,24 +95,29 @@ We also provide a model that was trained in-house on mesoSPIM nuclei data in col This plugin mainly uses the following libraries and software: -* `napari website`_ +* `napari`_ -* `PyTorch website`_ +* `PyTorch`_ -* `MONAI project website`_ (various models used here are credited `on their website`_) +* `MONAI project`_ (various models used here are credited `on their website`_) +* `pyclEsperanto`_ (for the Voronoi Otsu labeling) by Robert Haase + +* A custom re-implementation of the `WNet model`_ by Xia and Kulis [#]_ .. _Mathis Laboratory of Adaptive Motor Control: http://www.mackenziemathislab.org/ .. _Wyss Center: https://wysscenter.ch/ .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP -.. _napari website: https://napari.org/ -.. _PyTorch website: https://pytorch.org/ -.. _MONAI project website: https://monai.io/ +.. _napari: https://napari.org/ +.. _PyTorch: https://pytorch.org/ +.. _MONAI project: https://monai.io/ .. _on their website: https://docs.monai.io/en/stable/networks.html#nets - +.. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype +.. _WNet model: https://arxiv.org/abs/1711.08506 .. rubric:: References .. [#] Mapping mesoscale axonal projections in the mouse brain using a 3D convolutional network, Friedmann et al., 2020 ( https://pnas.org/cgi/doi/10.1073/pnas.1918465117 ) .. [#] The mesoSPIM initiative: open-source light-sheet microscopes for imaging cleared tissue, Voigt et al., 2019 ( https://doi.org/10.1038/s41592-019-0554-0 ) .. [#] MONAI Project website ( https://monai.io/ ) +.. 
[#] W-Net: A Deep Model for Fully Unsupervised Image Segmentation, Xia and Kulis, 2018 ( https://arxiv.org/abs/1711.08506 ) diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini new file mode 100644 index 00000000..45c3be1c --- /dev/null +++ b/napari_cellseg3d/_tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +qt_api=pyqt5 diff --git a/napari_cellseg3d/_tests/res/test_labels.tif b/napari_cellseg3d/_tests/res/test_labels.tif new file mode 100644 index 0000000000000000000000000000000000000000..0486d789ea658acc32616b40869833accf8d01d7 GIT binary patch literal 2026 zcmcK5yJ}QX6b9gPW+oRk7ZaUC6EDMfA4AYa#WzSNSc-*3f&q(wHX^pxK7x-TK7$V+ zh-hgcUQztNum?}HQam9)Yn^rd*V>ys8yll)x~i)As;YZc9c?nG8+xbi?%D^jcZTs? z;A_lNkw1zQI~mBsOB}G_#)iYWBJ~`{jpd=(^f$j8o7U4Q+J#zTzQ_J1_z;*uteC68 z$zUP)5+6=ue*NfXR{vChyE>l&{pDQ@Rs-|ldNz=U59v(k#{#wVv17fKBh6$ldYFA! zj2TD_w8=&Bcb7Z$2DI=OnVnep55ES3J|ZCRZ7^|q`;Z@w+f_hcAf uO7Fq{VSFjiRU3?7w#N8*ON^i72d14J-{`ip<7-oGF@Dt&<6PlCcKj3h 0: + for i in range(num_sliders): + widget = f"slider_{i}" + setattr( + self, + widget, + ui.Slider( + 0, + 100, + 1, + divide_factor=100, + text_label="", + parent=None, + ), + ) + self.sliders.append(getattr(self, widget)) + + if num_counters > 0: + for i in range(num_counters): + widget = f"counter_{i}" + setattr( + self, + widget, + ui.DoubleIncrementCounter(label="", parent=None), + ) + self.counters.append(getattr(self, widget)) + + def run_method(self, image): + raise NotImplementedError("Must be defined in child classes") + @dataclass class ImageStats: @@ -50,18 +109,48 @@ def get_dict(self): def threshold(volume, thresh): + """Remove all values smaller than the specified threshold in the volume""" im = np.squeeze(volume) binary = im > thresh return np.where(binary, im, np.zeros_like(im)) +def voronoi_otsu( + volume: np.ndarray, + spot_sigma: float, + outline_sigma: float, + # remove_small_size: float, +): + """ + Voronoi-Otsu labeling from pyclesperanto. + BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase + https://github.com/clEsperanto/napari_pyclesperanto_assistant + + Args: + volume (np.ndarray): volume to segment + spot_sigma (float): parameter determining how close detected objects can be + outline_sigma (float): determines the smoothness of the segmentation + + Returns: + Instance segmentation labels from Voronoi-Otsu method + + """ + # remove_small_size (float): remove all objects smaller than the specified size in pixels + # semantic = np.squeeze(volume) + logger.debug( + f"Running voronoi otsu segmentation with spot_sigma={spot_sigma} and outline_sigma={outline_sigma}" + ) + instance = cle.voronoi_otsu_labeling( + volume, spot_sigma=spot_sigma, outline_sigma=outline_sigma + ) + # instance = remove_small_objects(instance, remove_small_size) + return np.array(instance) + + def binary_connected( - volume, + volume: np.array, thres=0.5, thres_small=3, - # scale_factors=(1.0, 1.0, 1.0), - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via connected-component labeling. @@ -71,9 +160,14 @@ def binary_connected( thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. 
Default: (1.0, 1.0, 1.0) + """ + logger.debug( + f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" + ) + # if len(volume.shape) > 3: semantic = np.squeeze(volume) - foreground = semantic > thres # int(255 * thres) + foreground = np.where(semantic > thres, volume, 0) # int(255 * thres) segm = label(foreground) segm = remove_small_objects(segm, thres_small) @@ -97,28 +191,29 @@ def binary_connected( def binary_watershed( volume, thres_objects=0.3, - thres_small=10, thres_seeding=0.9, - # scale_factors=(1.0, 1.0, 1.0), + thres_small=10, rem_seed_thres=3, - *args, - **kwargs, ): r"""Convert binary foreground probability maps to instance masks via watershed segmentation algorithm. Note: This function uses the `skimage.segmentation.watershed `_ - function that converts the input image into ``np.float64`` data type for processing. Therefore please make sure enough memory is allocated when handling large arrays. + function that converts the input image into ``np.float64`` data type for processing. Therefore, please make sure enough memory is allocated when handling large arrays. Args: volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. - thres_seeding (float): threshold for seeding. Default: 0.98 thres_objects (float): threshold for foreground objects. Default: 0.3 + thres_seeding (float): threshold for seeding. Default: 0.9 thres_small (int): size threshold of small objects removal. Default: 10 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0) rem_seed_thres (int): threshold for small seeds removal. Default : 3 + """ + logger.debug( + f"Running watershed segmentation with thres_objects={thres_objects}, thres_seeding={thres_seeding}," + f" thres_small={thres_small} and rem_seed_thres={rem_seed_thres}" + ) semantic = np.squeeze(volume) seed_map = semantic > thres_seeding foreground = semantic > thres_objects @@ -193,7 +288,7 @@ def to_instance(image, is_file_path=False): result = binary_watershed( image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # TODO add params + ) # FIXME add params from utils plugin return result @@ -283,3 +378,211 @@ def fill(lst, n=len(properties) - 1): ratio, fill([len(properties)]), ) + + +class Watershed(InstanceMethod): + """Widget class for Watershed segmentation. Requires 4 parameters, see binary_watershed""" + + def __init__(self, widget_parent=None): + super().__init__( + name=WATERSHED, + function=binary_watershed, + num_sliders=2, + num_counters=2, + widget_parent=widget_parent, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(50) + + self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].tooltips = "Probability threshold for seeding" + self.sliders[1].setValue(90) + + self.counters[0].label.setText("Small object removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(30) + + self.counters[1].label.setText("Small seed removal") + self.counters[1].tooltips = ( + "Volume/size threshold for small seeds removal." + "\nAll seeds with a volume/size below this value will be removed." 
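# A standalone sketch of the seeded-watershed idea that binary_watershed (above)
# builds on: seeds come from a high probability threshold, and objects are flooded
# along the probability gradient within the foreground mask. Assumes scikit-image;
# the 0.5 / 0.9 values echo the widget defaults but are purely illustrative.
import numpy as np
from skimage.measure import label
from skimage.segmentation import watershed

probability_map = np.random.rand(16, 32, 32)  # stand-in model output
foreground = probability_map > 0.5  # foreground object mask
seeds = label(probability_map > 0.9)  # high-confidence seed markers
instance = watershed(-probability_map, markers=seeds, mask=foreground)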
+ ) + self.counters[1].setValue(3) + + def run_method(self, image): + return self.function( + image, + self.sliders[0].slider_value, + self.sliders[1].slider_value, + self.counters[0].value(), + self.counters[1].value(), + ) + + +class ConnectedComponents(InstanceMethod): + """Widget class for Connected Components instance segmentation. Requires 2 parameters, see binary_connected.""" + + def __init__(self, widget_parent=None): + super().__init__( + name=CONNECTED_COMP, + function=binary_connected, + num_sliders=1, + num_counters=1, + widget_parent=widget_parent, + ) + + self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[ + 0 + ].tooltips = "Probability threshold for foreground object" + self.sliders[0].setValue(80) + + self.counters[0].label.setText("Small objects removal") + self.counters[0].tooltips = ( + "Volume/size threshold for small object removal." + "\nAll objects with a volume/size below this value will be removed." + ) + self.counters[0].setValue(3) + + def run_method(self, image): + return self.function( + image, self.sliders[0].slider_value, self.counters[0].value() + ) + + +class VoronoiOtsu(InstanceMethod): + """Widget class for Voronoi-Otsu labeling from pyclesperanto. Requires 2 parameters, see voronoi_otsu""" + + def __init__(self, widget_parent=None): + super().__init__( + name=VORONOI_OTSU, + function=voronoi_otsu, + num_sliders=0, + num_counters=2, + widget_parent=widget_parent, + ) + self.counters[0].label.setText("Spot sigma") # closeness + self.counters[ + 0 + ].tooltips = "Determines how close detected objects can be" + self.counters[0].setMaximum(100) + self.counters[0].setValue(2) + + self.counters[1].label.setText("Outline sigma") # smoothness + self.counters[ + 1 + ].tooltips = "Determines the smoothness of the segmentation" + self.counters[1].setMaximum(100) + self.counters[1].setValue(2) + + # self.counters[2].label.setText("Small object removal") + # self.counters[2].tooltips = ( + # "Volume/size threshold for small object removal." + # "\nAll objects with a volume/size below this value will be removed." + # ) + # self.counters[2].setValue(30) + + def run_method(self, image): + ################ + # For debugging + # import napari + # view = napari.Viewer() + # view.add_image(image) + # napari.run() + ################ + + return self.function( + image, + self.counters[0].value(), + self.counters[1].value(), + # self.counters[2].value(), + ) + + +class InstanceWidgets(QWidget): + """ + Base widget with several sliders, for use in instance segmentation parameters + """ + + def __init__(self, parent=None): + """ + Creates an InstanceWidgets widget + + Args: + parent: parent widget + + """ + super().__init__(parent) + self.method_choice = ui.DropdownMenu( + list(INSTANCE_SEGMENTATION_METHOD_LIST.keys()) + ) + self.methods = {} + """Contains the instance of each method, with its name as key""" + self.instance_widgets = {} + """Contains the lists of widgets for each method, to show/hide""" + + self.method_choice.currentTextChanged.connect(self._set_visibility) + self._build() + + def _build(self): + group = ui.GroupedWidget("Instance segmentation") + group.layout.addWidget(self.method_choice) + + try: + for name, method in INSTANCE_SEGMENTATION_METHOD_LIST.items(): + method_class = method(widget_parent=self.parent()) + self.methods[name] = method_class + self.instance_widgets[name] = [] + # moderately unsafe way to init those widgets ?
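# For reference, the Voronoi-Otsu method wrapped by the VoronoiOtsu widget above
# can also be called directly. A minimal sketch, assuming pyclesperanto-prototype
# and a working OpenCL runtime; sigma values of 2 match the widget defaults.
import numpy as np
import pyclesperanto_prototype as cle

volume = np.random.rand(16, 32, 32).astype(np.float32)  # stand-in volume
labels = cle.voronoi_otsu_labeling(volume, spot_sigma=2, outline_sigma=2)
labels = np.array(labels)  # pull the result back from the GPU as a numpy array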
+ if len(method_class.sliders) > 0: + for slider in method_class.sliders: + group.layout.addWidget(slider.container) + self.instance_widgets[name].append(slider) + if len(method_class.counters) > 0: + for counter in method_class.counters: + group.layout.addWidget(counter.label) + group.layout.addWidget(counter) + self.instance_widgets[name].append(counter) + except RuntimeError as e: + logger.debug( + f"Caught runtime error {e}, most likely during testing" + ) + + self.setLayout(group.layout) + self._set_visibility() + + def _set_visibility(self): + for name in self.instance_widgets: + if name != self.method_choice.currentText(): + for widget in self.instance_widgets[name]: + widget.set_visibility(False) + else: + for widget in self.instance_widgets[name]: + widget.set_visibility(True) + + def run_method(self, volume): + """ + Calls instance function with chosen parameters + + Args: + volume: image data to run method on + + Returns: processed image from self._method + + """ + method = self.methods[self.method_choice.currentText()] + return method.run_method(volume) + + +INSTANCE_SEGMENTATION_METHOD_LIST = { + VORONOI_OTSU: VoronoiOtsu, + WATERSHED: Watershed, + CONNECTED_COMP: ConnectedComponents, +} diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/model_workers.py index c5675a11..30d37bbd 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -2,44 +2,46 @@ from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List -from typing import Optional +from typing import List, Optional import numpy as np import torch # MONAI -from monai.data import CacheDataset -from monai.data import DataLoader -from monai.data import Dataset -from monai.data import decollate_batch -from monai.data import pad_list_data_collate -from monai.data import PatchDataset +from monai.data import ( + CacheDataset, + DataLoader, + Dataset, + PatchDataset, + decollate_batch, + pad_list_data_collate, +) from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric -from monai.transforms import AddChannel -from monai.transforms import AsDiscrete -from monai.transforms import Compose -from monai.transforms import EnsureChannelFirstd -from monai.transforms import EnsureType -from monai.transforms import EnsureTyped -from monai.transforms import LoadImaged -from monai.transforms import Orientationd -from monai.transforms import Rand3DElasticd -from monai.transforms import RandAffined -from monai.transforms import RandFlipd -from monai.transforms import RandRotate90d -from monai.transforms import RandShiftIntensityd -from monai.transforms import RandSpatialCropSamplesd -from monai.transforms import SpatialPad -from monai.transforms import SpatialPadd -from monai.transforms import ToTensor -from monai.transforms import Zoom +from monai.transforms import ( + AddChannel, + AsDiscrete, + Compose, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + LoadImaged, + Orientationd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandShiftIntensityd, + RandSpatialCropSamplesd, + SpatialPad, + SpatialPadd, + ToTensor, + Zoom, +) from monai.utils import set_determinism # threads -from napari.qt.threading import GeneratorWorker -from napari.qt.threading import WorkerBaseSignals +from napari.qt.threading import GeneratorWorker, WorkerBaseSignals # Qt from qtpy.QtCore import Signal @@ -47,17 +49,12 @@ from tqdm import tqdm # local -from napari_cellseg3d 
import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, + ImageStats, + volume_stats, ) -from napari_cellseg3d.code_models.model_instance_seg import ImageStats -from napari_cellseg3d.code_models.model_instance_seg import volume_stats logger = utils.LOGGER @@ -118,7 +115,7 @@ def show_progress(count, block_size, total_size): with open(json_path) as f: neturls = json.load(f) - if model_name in neturls.keys(): + if model_name in neturls: url = neturls[model_name] response = urllib.request.urlopen(url) @@ -288,10 +285,11 @@ def log_parameters(self): f"Thresholding is enabled at {config.post_process_config.thresholding.threshold_value}" ) - if config.sliding_window_config.is_enabled(): - status = "enabled" - else: - status = "disabled" + status = ( + "enabled" + if config.sliding_window_config.is_enabled() + else "disabled" + ) self.log(f"Window inference is {status}\n") if status == "enabled": @@ -315,9 +313,7 @@ def log_parameters(self): instance_config = config.post_process_config.instance if instance_config.enabled: self.log( - f"Instance segmentation enabled, method : {instance_config.method}\n" - f"Probability threshold is {instance_config.threshold.threshold_value:.2f}\n" - f"Objects smaller than {instance_config.small_object_removal_threshold.threshold_value} pixels will be removed\n" + f"Instance segmentation enabled, method : {instance_config.method.name}\n" ) self.log("-" * 20) @@ -389,7 +385,7 @@ def load_folder(self): return inference_loader def load_layer(self): - self.log("Loading layer\n") + self.log("\nLoading layer\n") data = np.squeeze(self.config.layer) volume = np.array(data, dtype=np.int16) @@ -457,16 +453,14 @@ def model_output( # self.config.model_info.get_model().get_output(model, inputs) # ) - def model_output(inputs): return post_process_transforms( self.config.model_info.get_model().get_output(model, inputs) ) - if self.config.keep_on_cpu: - dataset_device = "cpu" - else: - dataset_device = self.config.device + dataset_device = ( + "cpu" if self.config.keep_on_cpu else self.config.device + ) window_size = self.config.sliding_window_config.window_size window_overlap = self.config.sliding_window_config.window_overlap @@ -551,7 +545,9 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): i + 1, ) if from_layer: - instance_labels = np.swapaxes(instance_labels, 0, 2) + instance_labels = np.swapaxes( + instance_labels, 0, 2 + ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -604,30 +600,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - threshold = ( - self.config.post_process_config.instance.threshold.threshold_value - ) - size_small = ( - self.config.post_process_config.instance.small_object_removal_threshold.threshold_value - ) - method_name = self.config.post_process_config.instance.method - - if method_name == "Watershed": # FIXME use dict in config instead - - def method(image): - return binary_watershed(image, threshold, size_small) - - elif method_name == "Connected components": - - def method(image): - return binary_connected(image, threshold, size_small) - - else: - raise NotImplementedError( - "Selected instance 
segmentation method is not defined" - ) - - instance_labels = method(to_instance) + method = self.config.post_process_config.instance.method + instance_labels = method.run_method(image=to_instance) instance_filepath = ( self.config.results_path @@ -1084,10 +1058,7 @@ def train(self): do_sampling = self.config.sampling if model_name == "SegResNet": - if do_sampling: - size = self.config.sample_size - else: - size = check + size = self.config.sample_size if do_sampling else check logger.info(f"Size of image : {size}") model = model_class.get_net( input_image_size=utils.get_padding_dim(size), @@ -1095,10 +1066,7 @@ def train(self): # dropout_prob=0.3, ) elif model_name == "SwinUNetR": - if do_sampling: - size = self.sample_size - else: - size = check + size = self.sample_size if do_sampling else check logger.info(f"Size of image : {size}") model = model_class.get_net( img_size=utils.get_padding_dim(size), diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 4cdc0a43..73913ab8 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -64,10 +64,7 @@ def create_conv( ) elif char == "g": is_before_conv = i < order.index("c") - if is_before_conv: - num_channels = in_channels - else: - num_channels = out_channels + num_channels = in_channels if is_before_conv else out_channels # use only one group if the given number of groups is greater than the number of channels if num_channels < num_groups: diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py index c5cc78d3..9591d054 100644 --- a/napari_cellseg3d/code_models/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -1,14 +1,10 @@ import torch.nn as nn from napari_cellseg3d.code_models.models.unet.buildingblocks import ( + DoubleConv, create_decoders, -) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( create_encoders, ) -from napari_cellseg3d.code_models.models.unet.buildingblocks import ( - DoubleConv, -) def number_of_features_per_level(init_channel_number, num_levels): @@ -68,7 +64,7 @@ def __init__( f_maps, num_levels=num_levels ) - assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert isinstance(f_maps, (list, tuple)) assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" # create encoder path diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 5191e66f..0a613ee7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -6,8 +6,7 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QTabWidget -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QTabWidget, QWidget # local from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 5560b7b9..6c8370c1 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -4,18 +4,16 @@ import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QWidget -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite import napari_cellseg3d.interface as ui -from napari_cellseg3d import config from napari_cellseg3d import utils 
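# The worker refactor above replaces the old if/elif branching on method names with
# a direct call to the configured InstanceMethod object. A minimal sketch of that
# dispatch pattern; "DummyMethod" is hypothetical, for illustration only.
import numpy as np

class DummyMethod:
    name = "Dummy"

    def run_method(self, image):
        return image  # identity stand-in for a real instance segmentation routine

method = DummyMethod()
instance_labels = method.run_method(image=np.zeros((8, 8, 8)))  # no name branching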
from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceWidgets, clear_small_objects, + threshold, + to_semantic, ) -from napari_cellseg3d.code_models.model_instance_seg import threshold -from napari_cellseg3d.code_models.model_instance_seg import to_semantic from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder # TODO break down into multiple mini-widgets @@ -38,7 +36,7 @@ def save_folder(results_path, folder_name, images, image_paths): image_paths: list of filenames of images """ results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False) + results_folder.mkdir(exist_ok=False, parents=True) for file, image in zip(image_paths, images): path = results_folder / Path(file).name @@ -147,14 +145,14 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) isotropic_image = utils.resize(data, zoom) save_layer( @@ -169,18 +167,19 @@ def _start(self): f"isotropic_{layer.name}", ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - utils.resize(np.array(imread(file), dtype=np.int16), zoom) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"isotropic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + utils.resize(np.array(imread(file)), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class RemoveSmallUtils(BasePluginFolder): @@ -246,14 +245,14 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.size_for_removal_counter.value() if self.layer_choice: if self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( @@ -264,18 +263,19 @@ def _start(self): show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - clear_small_objects(file, remove_size, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"small_removed_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) return @@ -328,13 +328,13 @@ def _build(self): ) def _start(self): - Path(self.results_path).mkdir(exist_ok=True) + Path(self.results_path).mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) semantic = to_semantic(data) save_layer( 
@@ -345,171 +345,19 @@ def _start(self): show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - to_semantic(file, is_file_path=True) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"semantic_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) - - -class InstanceWidgets(QWidget): - """ - Base widget with several sliders, for use in instance segmentation parameters - """ - - def __init__(self, parent=None): - """ - Creates an InstanceWidgets widget - - Args: - parent: parent widget - """ - super().__init__(parent) - - self.method_choice = ui.DropdownMenu( - config.INSTANCE_SEGMENTATION_METHOD_LIST.keys() - ) - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[ - self.method_choice.currentText() - ] - - self.method_choice.currentTextChanged.connect(self._show_connected) - self.method_choice.currentTextChanged.connect(self._show_watershed) - - self.threshold_slider1 = ui.Slider( - lower=0, - upper=100, - default=50, - divide_factor=100.0, - step=5, - text_label="Probability threshold :", - ) - """Base prob. threshold""" - self.threshold_slider2 = ui.Slider( - lower=0, - upper=100, - default=90, - divide_factor=100.0, - step=5, - text_label="Probability threshold (seeding) :", - ) - """Second prob. thresh. (seeding)""" - - self.counter1 = ui.IntIncrementCounter( - upper=100, - default=10, - step=5, - label="Small object removal (pxs) :", - ) - """Small obj. rem.""" - - self.counter2 = ui.IntIncrementCounter( - upper=100, - default=3, - step=5, - label="Small seed removal (pxs) :", - ) - """Small seed rem.""" - - self._build() - - def run_method(self, volume): - """ - Calls instance function with chosen parameters - Args: - volume: image data to run method on - - Returns: processed image from self._method - """ - return self._method( - volume, - self.threshold_slider1.slider_value, - self.counter1.value(), - self.threshold_slider2.slider_value, - self.counter2.value(), - ) - - def _build(self): - group = ui.GroupedWidget("Instance segmentation") - - ui.add_widgets( - group.layout, - [ - self.method_choice, - self.threshold_slider1.container, - self.threshold_slider2.container, - self.counter1.label, - self.counter1, - self.counter2.label, - self.counter2, - ], - ) - - self.setLayout(group.layout) - self._set_tooltips() - - def _set_tooltips(self): - self.method_choice.setToolTip( - "Choose which method to use for instance segmentation" - "\nConnected components : all separated objects will be assigned an unique ID. " - "Robust but will not work correctly with adjacent/touching objects\n" - "Watershed : assigns objects ID based on the probability gradient surrounding an object. " - "Requires the model to surround objects in a gradient;" - " can possibly correctly separate unique but touching/adjacent objects." 
- ) - self.threshold_slider1.tooltips = ( - "All objects below this probability will be ignored (set to 0)" - ) - self.counter1.setToolTip( - "Will remove all objects smaller (in volume) than the specified number of pixels" - ) - self.threshold_slider2.tooltips = ( - "All seeds below this probability will be ignored (set to 0)" - ) - self.counter2.setToolTip( - "Will remove all seeds smaller (in volume) than the specified number of pixels" - ) - - def _show_watershed(self): - name = "Watershed" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2() - self._show_counter1() - self._show_counter2() - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_connected(self): - name = "Connected components" - if self.method_choice.currentText() == name: - self._show_slider1() - self._show_slider2(False) - self._show_counter1() - self._show_counter2(False) - - self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] - - def _show_slider1(self, is_visible: bool = True): - self.threshold_slider1.container.setVisible(is_visible) - - def _show_slider2(self, is_visible: bool = True): - self.threshold_slider2.container.setVisible(is_visible) - - def _show_counter1(self, is_visible: bool = True): - self.counter1.setVisible(is_visible) - self.counter1.label.setVisible(is_visible) - - def _show_counter2(self, is_visible: bool = True): - self.counter2.setVisible(is_visible) - self.counter2.label.setVisible(is_visible) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ToInstanceUtils(BasePluginFolder): @@ -534,7 +382,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.data_panel = self._build_io_panel() self.label_layer_loader.set_layer_type(napari.layers.Layer) - self.instance_widgets = InstanceWidgets() + self.instance_widgets = InstanceWidgets(parent=self) self.start_btn = ui.Button("Start", self._start) @@ -566,13 +414,13 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: layer = self.label_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) instance = self.instance_widgets.run_method(data) save_layer( @@ -584,18 +432,19 @@ def _start(self): instance, name=f"instance_{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.instance_widgets.run_method(imread(file)) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"instance_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) class ThresholdUtils(BasePluginFolder): @@ -660,14 +509,14 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True) + self.results_path.mkdir(exist_ok=True, parents=True) remove_size = self.binarize_counter.value() if self.layer_choice: if 
self.image_layer_loader.layer_data() is not None: layer = self.image_layer_loader.layer() - data = np.array(layer.data, dtype=np.int16) + data = np.array(layer.data) removed = self.function(data, remove_size) save_layer( @@ -678,18 +527,19 @@ def _start(self): show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) - elif self.folder_choice.isChecked(): - if len(self.images_filepaths) != 0: - images = [ - self.function(imread(file), remove_size) - for file in self.images_filepaths - ] - save_folder( - self.results_path, - f"threshold_results_{utils.get_date_time()}", - images, - self.images_filepaths, - ) + elif ( + self.folder_choice.isChecked() and len(self.images_filepaths) != 0 + ): + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) # class ConvertUtils(BasePluginFolder): diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 97485bf4..9830d51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -11,9 +11,7 @@ # local from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage DEFAULT_CROP_SIZE = 64 logger = utils.LOGGER @@ -177,8 +175,8 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 400]) - self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding) + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() # def _check_results_path(self, folder): @@ -235,10 +233,7 @@ def quicksave(self): def _check_ready(self): if self.image_layer_loader.layer_data() is not None: if self.crop_second_image: - if self.label_layer_loader.layer_data() is not None: - return True - else: - return False + return self.label_layer_loader.layer_data() is not None return True return False @@ -308,7 +303,7 @@ def _start(self): else: self.image_layer1.opacity = 0.7 self.image_layer1.colormap = "inferno" - self.image_layer1.contrast_limits = [200, 1000] # TODO generalize + # self.image_layer1.contrast_limits = [200, 1000] # TODO generalize self.image_layer1.refresh() diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index 083b269b..f8ac18ef 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -4,10 +4,8 @@ # Qt from qtpy.QtCore import QSize -from qtpy.QtGui import QIcon -from qtpy.QtGui import QPixmap -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtGui import QIcon, QPixmap +from qtpy.QtWidgets import QVBoxLayout, QWidget # local from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index b2356526..114025f6 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -5,8 +5,7 @@ FigureCanvasQTAgg as FigureCanvas, ) from matplotlib.figure import Figure -from monai.transforms import SpatialPad -from monai.transforms import ToTensor +from monai.transforms import SpatialPad, ToTensor from 
tifffile import imread from napari_cellseg3d import interface as ui diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 2693b178..22867343 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -6,13 +6,17 @@ import pandas as pd # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import InferenceResult -from napari_cellseg3d.code_models.model_workers import InferenceWorker -from napari_cellseg3d.code_plugins.plugin_convert import InstanceWidgets +from napari_cellseg3d.code_models.model_instance_seg import ( + InstanceMethod, + InstanceWidgets, +) +from napari_cellseg3d.code_models.model_workers import ( + InferenceResult, + InferenceWorker, +) class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -77,9 +81,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): config.InferenceWorkerConfig() ) """InferenceWorkerConfig class from config.py""" - self.instance_config: config.InstanceSegConfig = ( - config.InstanceSegConfig() - ) + self.instance_config: InstanceMethod """InstanceSegConfig class from config.py""" self.post_process_config: config.PostProcessConfig = ( config.PostProcessConfig() @@ -185,11 +187,12 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## # instance segmentation widgets - self.instance_widgets = InstanceWidgets(self) + self.instance_widgets = InstanceWidgets(parent=self) self.use_instance_choice = ui.CheckBox( "Run instance segmentation", func=self._toggle_display_instance @@ -276,9 +279,11 @@ def check_ready(self): if self.layer_choice.isChecked(): if self.image_layer_loader.layer_data() is not None: return True - elif self.folder_choice.isChecked(): - if self.image_filewidget.check_ready(): - return True + elif ( + self.folder_choice.isChecked() + and self.image_filewidget.check_ready() + ): + return True return False def _toggle_display_model_input_size(self): @@ -551,17 +556,11 @@ def start(self): threshold_value=self.thresholding_slider.slider_value, ) - instance_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.threshold_slider1.slider_value - ) - instance_small_object_thresh_config = config.Thresholding( - threshold_value=self.instance_widgets.counter1.value() - ) self.instance_config = config.InstanceSegConfig( enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.method_choice.currentText(), - threshold=instance_thresh_config, - small_object_removal_threshold=instance_small_object_thresh_config, + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], ) self.post_process_config = config.PostProcessConfig( @@ -730,13 +729,15 @@ def on_yield(self, result: InferenceResult): if result.instance_labels is not None: labels = result.instance_labels - method = self.worker_config.post_process_config.instance.method + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) number_cells = ( np.unique(labels.flatten()).size - 1 ) # remove background - name = f"({number_cells} 
objects)_{method}_instance_labels_{image_id}" + name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" viewer.add_labels(labels, name=name) @@ -750,7 +751,7 @@ def on_yield(self, result: InferenceResult): f"Number of instances : {stats.number_objects}" ) - csv_name = f"/{method}_seg_results_{image_id}_{utils.get_date_time()}.csv" + csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" stats_df.to_csv( self.worker_config.results_path + csv_name, index=False, diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index de54b345..cf8e4b85 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -14,23 +14,26 @@ from matplotlib.figure import Figure # MONAI -from monai.losses import DiceCELoss -from monai.losses import DiceFocalLoss -from monai.losses import DiceLoss -from monai.losses import FocalLoss -from monai.losses import GeneralizedDiceLoss -from monai.losses import TverskyLoss +from monai.losses import ( + DiceCELoss, + DiceFocalLoss, + DiceLoss, + FocalLoss, + GeneralizedDiceLoss, + TverskyLoss, +) # Qt from qtpy.QtWidgets import QSizePolicy # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import TrainingReport -from napari_cellseg3d.code_models.model_workers import TrainingWorker +from napari_cellseg3d.code_models.model_workers import ( + TrainingReport, + TrainingWorker, +) NUMBER_TABS = 3 DEFAULT_PATCH_SIZE = 64 diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 0044c8e2..7ed6c549 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -11,17 +11,13 @@ from matplotlib.figure import Figure # Qt -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QLineEdit, QSizePolicy from tifffile import imwrite # local -from napari_cellseg3d import config +from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.code_plugins.plugin_base import ( - BasePluginSingleImage, -) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager warnings.formatwarning = utils.format_Warning @@ -184,11 +180,10 @@ def check_image_data(self): if cfg.image is None: raise ValueError("Review requires at least one image") - if cfg.labels is not None: - if cfg.image.shape != cfg.labels.shape: - warnings.warn( - "Image and label dimensions do not match ! Please load matching images" - ) + if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: + warnings.warn( + "Image and label dimensions do not match ! 
Please load matching images" + ) def _prepare_data(self): if self.layer_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index 8a25d6a6..c09c376f 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,14 +1,12 @@ import warnings -from datetime import datetime -from datetime import timedelta +from datetime import datetime, timedelta from pathlib import Path import napari import pandas as pd # Qt -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QVBoxLayout, QWidget from napari_cellseg3d import interface as ui from napari_cellseg3d import utils @@ -216,10 +214,7 @@ def create_csv(self, label_dir, model_type, filename=None): ) else: # print(self.image_dims[0]) - if self.filename is not None: - filename = self.filename - else: - filename = "image" + filename = self.filename if self.filename is not None else "image" labels = [str(filename) for i in range(self.image_dims[0])] df = pd.DataFrame( diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 9e66213f..5463a4ff 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -2,20 +2,17 @@ # Qt from qtpy.QtCore import qInstallMessageHandler -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QSizePolicy, QVBoxLayout, QWidget # local import napari_cellseg3d.interface as ui -from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils from napari_cellseg3d.code_plugins.plugin_convert import ( + AnisoUtils, RemoveSmallUtils, + ThresholdUtils, + ToInstanceUtils, + ToSemanticUtils, ) -from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils -from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -60,10 +57,10 @@ def _build(self): layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) - layout.setSizeConstraint(QLayout.SetFixedSize) + # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) - self.setMinimumHeight(1000) - self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + # self.setMinimumHeight(2000) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._update_visibility() def _create_utils_widgets(self, names): @@ -79,14 +76,13 @@ def _create_utils_widgets(self, names): raise RuntimeError( "One or several utility widgets are missing/erroneous" ) - # TODO how to auto-update list based on UTILITIES_WIDGETS ? def _update_visibility(self): widget_class = UTILITIES_WIDGETS[self.utils_choice.currentText()] # print("vis. 
updated") # print(self.utils_widgets) self._hide_all() - for i, w in enumerate(self.utils_widgets): + for _i, w in enumerate(self.utils_widgets): if isinstance(w, widget_class): w.setVisible(True) w.adjustSize() diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 57c65bac..737b53aa 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -2,24 +2,16 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import List -from typing import Optional +from typing import List, Optional import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_connected, -) -from napari_cellseg3d.code_models.model_instance_seg import ( - binary_watershed, -) -from napari_cellseg3d.code_models.models import ( - model_SegResNet as SegResNet, -) -from napari_cellseg3d.code_models.models import ( - model_SwinUNetR as SwinUNetR, -) +from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod + +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet +from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR from napari_cellseg3d.code_models.models import ( model_TRAILMAP_MS as TRAILMAP_MS, ) @@ -40,10 +32,6 @@ # "test" : DO NOT USE, reserved for testing } -INSTANCE_SEGMENTATION_METHOD_LIST = { - "Watershed": binary_watershed, - "Connected components": binary_connected, -} WEIGHTS_DIR = str( Path(__file__).parent.resolve() / Path("code_models/models/pretrained") @@ -98,9 +86,7 @@ def get_model(self): @staticmethod def get_model_name_list(): - logger.info( - "Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) - ) + logger.info("Model list :\n" + str(f"{name}\n" for name in MODEL_LIST)) return MODEL_LIST.keys() @@ -130,11 +116,7 @@ class Zoom: @dataclass class InstanceSegConfig: enabled: bool = False - method: str = None - threshold: Thresholding = Thresholding(enabled=False, threshold_value=0.85) - small_object_removal_threshold: Thresholding = Thresholding( - enabled=True, threshold_value=20 - ) + method: InstanceMethod = None @dataclass diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py new file mode 100644 index 00000000..3f95e1a8 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -0,0 +1,433 @@ +import os + +import napari +import numpy as np +import scipy.ndimage as ndimage +from skimage.filters import threshold_otsu +from tifffile import imread, imwrite + +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed + +# import sys +# sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +""" +New code by Yves Paychere +Creates labels of artifacts in an image based on existing labels of neurons +""" + + +def map_labels(labels, artefacts): + """Map the artefacts labels to the neurons labels. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + artefacts : ndarray + Label image with artefacts labelled as mulitple values. 
+    Returns
+    -------
+    map_labels_existing: numpy array
+        The label value of the artefact and the label value(s) of the associated neuron(s)
+    new_labels: list
+        The labels of the artefacts that do not correspond to any neuron
+    """
+    map_labels_existing = []
+    new_labels = []
+
+    for i in np.unique(artefacts):
+        if i == 0:
+            continue
+        indexes = labels[artefacts == i]
+        # find the most common label in the indexes
+        unique, counts = np.unique(indexes, return_counts=True)
+        unique = np.flip(unique[np.argsort(counts)])
+        counts = np.flip(counts[np.argsort(counts)])
+        if unique[0] != 0:
+            map_labels_existing.append(
+                np.array([i, unique[np.argmax(counts)]])
+            )
+        elif (
+            counts[0] < np.sum(counts) * 2 / 3.0
+        ):  # the artefact is connected to multiple neurons
+            total = 0
+            ii = 1
+            while total < np.size(indexes) / 3.0:
+                total = np.sum(counts[1 : ii + 1])
+                ii += 1
+            map_labels_existing.append(np.append([i], unique[1 : ii + 1]))
+        else:
+            new_labels.append(i)
+
+    return map_labels_existing, new_labels
+
+
+def make_labels(
+    image,
+    path_labels_out,
+    threshold_factor=1,
+    threshold_size=30,
+    label_value=1,
+    do_multi_label=True,
+    use_watershed=True,
+    augment_contrast_factor=2,
+):
+    """Detect nuclei using a binary watershed algorithm and Otsu thresholding.
+    Parameters
+    ----------
+    image : str
+        Path of the input image.
+    path_labels_out : str
+        Path of the output labelled image.
+    threshold_factor : float, optional
+        Factor applied to the Otsu threshold before binarization.
+    threshold_size : int, optional
+        Threshold for nucleus size; nuclei smaller than this value are removed.
+    label_value : int, optional
+        Value to use for the label image.
+    do_multi_label : bool, optional
+        If True, each nucleus is labelled with a different value.
+    use_watershed : bool, optional
+        If True, use the watershed algorithm to detect nuclei.
+    augment_contrast_factor : int, optional
+        Factor used to augment the contrast of the image.
+    Returns
+    -------
+    None
+        The label image, with one value per nucleus, is written to ``path_labels_out``.
+    """
+
+    image = imread(image)
+    image = (image - np.min(image)) / (np.max(image) - np.min(image))
+
+    threshold_brightness = threshold_otsu(image) * threshold_factor
+    image_contrasted = np.where(image > threshold_brightness, image, 0)
+
+    if use_watershed:
+        image_contrasted = (image_contrasted - np.min(image_contrasted)) / (
+            np.max(image_contrasted) - np.min(image_contrasted)
+        )
+
+        image_contrasted = image_contrasted * augment_contrast_factor
+        image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted)
+        labels = binary_watershed(image_contrasted, thres_small=threshold_size)
+    else:
+        labels = ndimage.label(image_contrasted)[0]
+
+    labels = select_artefacts_by_size(
+        labels, min_size=threshold_size, is_labeled=True
+    )
+
+    if not do_multi_label:
+        labels = np.where(labels > 0, label_value, 0)
+
+    imwrite(path_labels_out, labels.astype(np.uint16))
+    imwrite(
+        path_labels_out.replace(".tif", "_contrast.tif"),
+        image_contrasted.astype(np.float32),
+    )
+
+
+def select_image_by_labels(image, labels, path_image_out, label_values):
+    """Select the parts of an image that belong to the given labels.
+    Parameters
+    ----------
+    image : np.array
+        Image data.
+    labels : np.array
+        Label image.
+    path_image_out : str
+        Path of the output image.
+    label_values : list
+        List of label values to select.
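+    Examples
+    --------
+    Minimal usage sketch; the file names and label values below are
+    hypothetical placeholders, not files shipped with this repository::
+
+        from tifffile import imread
+
+        image = imread("volume.tif")
+        labels = imread("labels.tif")
+        # keep only the voxels belonging to labels 1 and 3,
+        # write the result as float32 to selected.tif
+        select_image_by_labels(image, labels, "selected.tif", [1, 3])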
+ """ + # image = imread(image) + # labels = imread(labels) + + image = np.where(np.isin(labels, label_values), image, 0) + imwrite(path_image_out, image.astype(np.float32)) + + +# select the smallest cube that contains all the non-zero pixels of a 3d image +def get_bounding_box(img): + height = np.any(img, axis=(0, 1)) + rows = np.any(img, axis=(0, 2)) + cols = np.any(img, axis=(1, 2)) + + xmin, xmax = np.where(cols)[0][[0, -1]] + ymin, ymax = np.where(rows)[0][[0, -1]] + zmin, zmax = np.where(height)[0][[0, -1]] + return xmin, xmax, ymin, ymax, zmin, zmax + + +# crop the image +def crop_image(img): + xmin, xmax, ymin, ymax, zmin, zmax = get_bounding_box(img) + return img[xmin:xmax, ymin:ymax, zmin:zmax] + + +def crop_image_path(image, path_image_out): + """Crop image. + Parameters + ---------- + image : np.array + image + path_image_out : str + Path of the output image. + """ + image = crop_image(image) + imwrite(path_image_out, image.astype(np.float32)) + + +def make_artefact_labels( + image, + labels, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, + label_value=2, + do_multi_label=False, + remove_true_labels=True, +): + """Detect pseudo nucleus. + Parameters + ---------- + image : ndarray + Image. + labels : ndarray + Label image. + threshold_artefact_brightness_percent : int, optional + Threshold for artefact brightness. + threshold_artefact_size_percent : int, optional + Threshold for artefact size, if the artefcact is smaller than this percentage of the neurons it will be removed. + contrast_power : int, optional + Power for contrast enhancement. + label_value : int, optional + Value to use for the label image. + do_multi_label : bool, optional + If True, each different artefact will be labelled as a different value. + remove_true_labels : bool, optional + If True, the true labels will be removed from the artefacts. + Returns + ------- + ndarray + Label image with pseudo nucleus labelled with 1 value per artefact. 
+ """ + + neurons = np.array(labels > 0) + non_neurons = np.array(labels == 0) + + image = (image - np.min(image)) / (np.max(image) - np.min(image)) + + # calculate the percentile of the intensity of all the pixels that are labeled as neurons + # check if the neurons are not empty + if np.sum(neurons) > 0: + threshold = np.percentile( + image[neurons], threshold_artefact_brightness_percent + ) + else: + # take the percentile of the non neurons if the neurons are empty + threshold = np.percentile(image[non_neurons], 90) + + # modify the contrast of the image accoring to the threshold with a tanh function and map the values to [0,1] + + image_contrasted = np.tanh((image - threshold) * contrast_power) + image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( + np.max(image_contrasted) - np.min(image_contrasted) + ) + + artefacts = binary_watershed( + image_contrasted, thres_seeding=0.95, thres_small=15, thres_objects=0.4 + ) + + if remove_true_labels: + # evaluate where the artefacts are connected to the neurons + # map the artefacts label to the neurons label + map_labels_existing, new_labels = map_labels(labels, artefacts) + + # remove the artefacts that are connected to the neurons + for i in map_labels_existing: + artefacts[artefacts == i[0]] = 0 + # remove all the pixels of the neurons from the artefacts + artefacts = np.where(labels > 0, 0, artefacts) + + # remove the artefacts that are too small + # calculate the percentile of the size of the neurons + if np.sum(neurons) > 0: + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + neurone_size_percentile = np.percentile( + sizes, threshold_artefact_size_percent + ) + else: + # find the size of each connected component + sizes = ndimage.sum_labels(labels > 0, labels, np.unique(labels)) + # remove the smallest connected components + neurone_size_percentile = np.percentile(sizes, 95) + + # select the artefacts that are bigger than the percentile + + artefacts = select_artefacts_by_size( + artefacts, min_size=neurone_size_percentile, is_labeled=True + ) + + # relabel with the label value if the artefacts are not multi label + if not do_multi_label: + artefacts = np.where(artefacts > 0, label_value, artefacts) + + return artefacts + + +def select_artefacts_by_size(artefacts, min_size, is_labeled=False): + """Select artefacts by size. + Parameters + ---------- + artefacts : ndarray + Label image with artefacts labelled as 1. + min_size : int, optional + Minimum size of artefacts to keep + is_labeled : bool, optional + If True, the artefacts are already labelled. + Returns + ------- + ndarray + Label image with artefacts labelled and small artefacts removed. + """ + if not is_labeled: + # find all the connected components in the artefacts image + labels = ndimage.label(artefacts)[0] + else: + labels = artefacts + + # remove the small components + labels_i, counts = np.unique(labels, return_counts=True) + labels_i = labels_i[counts > min_size] + labels_i = labels_i[labels_i > 0] + artefacts = np.where(np.isin(labels, labels_i), labels, 0) + return artefacts + + +def create_artefact_labels( + image, + labels, + output_path, + threshold_artefact_brightness_percent=40, + threshold_artefact_size_percent=1, + contrast_power=20, +): + """Create a new label image with artefacts labelled as 2 and neurons labelled as 1. + Parameters + ---------- + image : np.array + image for artefact detection. + labels : np.array + label image array with each neurons labelled as a different int value. 
+    output_path : str
+        Path to save the output label image file.
+    threshold_artefact_brightness_percent : int, optional
+        The artefacts need to be at least as bright as this percentage of the neurons' pixels.
+    threshold_artefact_size_percent : int, optional
+        The artefacts need to be at least as big as this percentage of the neurons.
+    contrast_power : int, optional
+        Power for contrast enhancement.
+    """
+    artefacts = make_artefact_labels(
+        image,
+        labels,
+        threshold_artefact_brightness_percent,
+        threshold_artefact_size_percent,
+        contrast_power=contrast_power,
+        label_value=2,
+        do_multi_label=False,
+    )
+
+    neurons_artefacts_labels = np.where(labels > 0, 1, artefacts)
+    imwrite(output_path, neurons_artefacts_labels)
+
+
+def visualize_images(paths):
+    """Visualize images.
+    Parameters
+    ----------
+    paths : list
+        List of paths of the images to visualize.
+    """
+    viewer = napari.Viewer(ndisplay=3)
+    for path in paths:
+        image = imread(path)
+        viewer.add_image(image)
+    # wait for the user to close the viewer
+    napari.run()
+
+
+def create_artefact_labels_from_folder(
+    path,
+    do_visualize=False,
+    threshold_artefact_brightness_percent=40,
+    threshold_artefact_size_percent=1,
+    contrast_power=20,
+):
+    """Create a new label image with artefacts labelled as 2 and neurons labelled as 1 for all images in a folder. The images created are stored in a folder artefact_neurons.
+    Parameters
+    ----------
+    path : str
+        Path to a folder with images in the subfolder ``volumes`` and labels in the subfolder ``labels``. The images are expected to have the same alphabetical order in both folders.
+    do_visualize : bool, optional
+        If True, the images will be visualized.
+    threshold_artefact_brightness_percent : int, optional
+        The artefacts need to be at least as bright as this percentage of the neurons' pixels.
+    threshold_artefact_size_percent : int, optional
+        The artefacts need to be at least as big as this percentage of the neurons.
+    contrast_power : int, optional
+        Power for contrast enhancement.
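+    Examples
+    --------
+    Illustrative sketch; ``dataset/train`` is a hypothetical folder that
+    contains the ``volumes`` and ``labels`` subfolders described above::
+
+        create_artefact_labels_from_folder(
+            "dataset/train",
+            do_visualize=False,
+        )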
+ """ + # find all the images in the folder and create a list + path_labels = [ + f for f in os.listdir(path + "/labels") if f.endswith(".tif") + ] + path_images = [ + f for f in os.listdir(path + "/volumes") if f.endswith(".tif") + ] + # sort the list + path_labels.sort() + path_images.sort() + # create the output folder + os.makedirs(path + "/artefact_neurons", exist_ok=True) + # create the artefact labels + for i in range(len(path_images)): + print(path_labels[i]) + # consider that the images and the labels have names in the same alphabetical order + create_artefact_labels( + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + threshold_artefact_brightness_percent, + threshold_artefact_size_percent, + contrast_power, + ) + if do_visualize: + visualize_images( + [ + path + "/volumes/" + path_images[i], + path + "/labels/" + path_labels[i], + path + "/artefact_neurons/" + path_labels[i], + ] + ) + + +# if __name__ == "__main__": +# repo_path = Path(__file__).resolve().parents[1] +# print(f"REPO PATH : {repo_path}") +# paths = [ +# "dataset_clean/cropped_visual/train", +# "dataset_clean/cropped_visual/val", +# "dataset_clean/somatomotor", +# "dataset_clean/visual_tif", +# ] +# for data_path in paths: +# path = str(repo_path / data_path) +# print(path) +# create_artefact_labels_from_folder( +# path, +# do_visualize=False, +# threshold_artefact_brightness_percent=20, +# threshold_artefact_size_percent=1, +# contrast_power=20, +# ) diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py index d772a1c2..641de627 100644 --- a/napari_cellseg3d/dev_scripts/convert.py +++ b/napari_cellseg3d/dev_scripts/convert.py @@ -2,8 +2,7 @@ import os import numpy as np -from tifffile import imread -from tifffile import imwrite +from tifffile import imread, imwrite # input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" # output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" @@ -20,7 +19,6 @@ # print(os.path.basename(filename)) for file in paths: image = imread(file) - # image = img.compute() image[image >= 1] = 1 image = image.astype(np.uint16) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py new file mode 100644 index 00000000..168990e1 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -0,0 +1,369 @@ +import threading +import time +import warnings +from functools import partial +from pathlib import Path + +import napari +import numpy as np +import scipy.ndimage as ndimage +from napari.qt.threading import thread_worker +from tifffile import imread, imwrite +from tqdm import tqdm + +import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed + +# import sys +# sys.path.append(str(Path(__file__) / "../../")) +""" +New code by Yves Paychère +Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold +""" + + +def relabel_non_unique_i(label, save_path, go_fast=False): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + label : np.array + the label image + save_path : str + the path to save the relabeld image + """ + value_label = 0 + new_labels = np.zeros_like(label) + map_labels_existing = [] + 
unique_label = np.unique(label) + for i_label in tqdm( + range(len(unique_label)), desc="relabeling", ncols=100 + ): + i = unique_label[i_label] + if i == 0: + continue + if go_fast: + new_label, to_add = ndimage.label(label == i) + map_labels_existing.append( + [i, list(range(value_label + 1, value_label + to_add + 1))] + ) + + else: + # catch the warning of the watershed + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_label = binary_watershed(label == i) + unique = np.unique(new_label) + to_add = unique[-1] + map_labels_existing.append([i, unique[1:] + value_label]) + + new_label[new_label != 0] += value_label + new_labels += new_label + value_label += to_add + + imwrite(save_path, new_labels) + return map_labels_existing + + +def add_label(old_label, artefact, new_label_path, i_labels_to_add): + """add the label to the label image + Parameters + ---------- + old_label : np.array + the label image + artefact : np.array + the artefact image that contains some neurons + new_label_path : str + the path to save the new label image + """ + new_label = old_label.copy() + max_label = np.max(old_label) + for i, i_label in enumerate(i_labels_to_add): + new_label[artefact == i_label] = i + max_label + 1 + imwrite(new_label_path, new_label) + + +returns = [] + + +def ask_labels(unique_artefact, test=False): + global returns + returns = [] + if not test: + i_labels_to_add_tmp = input( + "Which labels do you want to add (0 to skip) ? (separated by a comma):" + ) + i_labels_to_add_tmp = [int(i) for i in i_labels_to_add_tmp.split(",")] + else: + i_labels_to_add_tmp = [0] + + if i_labels_to_add_tmp == [0]: + print("no label added") + returns = [[]] + print("close the napari window to continue") + return + + for i in i_labels_to_add_tmp: + if i == 0: + print("0 is not a valid label") + # delete the 0 + i_labels_to_add_tmp.remove(i) + # test if all index are negative + if all(i < 0 for i in i_labels_to_add_tmp): + print( + "all labels are negative-> will add all the labels except the one you gave" + ) + i_labels_to_add = list(unique_artefact) + for i in i_labels_to_add_tmp: + if np.abs(i) in i_labels_to_add: + i_labels_to_add.remove(np.abs(i)) + else: + print("the label", np.abs(i), "is not in the label image") + i_labels_to_add_tmp = i_labels_to_add + else: + # remove the negative index + for i in i_labels_to_add_tmp: + if i < 0: + i_labels_to_add_tmp.remove(i) + print( + "ignore the negative label", + i, + " since not all the labels are negative", + ) + if i not in unique_artefact: + print("the label", i, "is not in the label image") + i_labels_to_add_tmp.remove(i) + + returns = [i_labels_to_add_tmp] + print("close the napari window to continue") + + +def relabel( + image_path, + label_path, + go_fast=False, + check_for_unicity=True, + delay=0.3, + viewer=None, + test=False, +): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + image_path : str + the path to the image + label_path : str + the path to the label image + go_fast : bool, optional + if True, the relabeling will be faster but the labels can more frequently be merged, by default False + check_for_unicity : bool, optional + if True, the relabeling will check if the labels are unique, by default True + delay : float, optional + the delay between each image for the visualization, by default 0.3 + viewer : napari.Viewer, optional + the napari viewer, by default None + """ + global returns + + label = imread(label_path) + 
initial_label_path = label_path + if check_for_unicity: + # check if the label are unique + new_label_path = label_path[:-4] + "_relabel_unique.tif" + map_labels_existing = relabel_non_unique_i( + label, new_label_path, go_fast=go_fast + ) + print( + "visualize the relabeld image in white the previous labels and in red the new labels" + ) + if not test: + visualize_map( + map_labels_existing, label_path, new_label_path, delay=delay + ) + label_path = new_label_path + # detect artefact + print("detection of potential neurons (in progress)") + image = imread(image_path) + artefact = make_artefact_labels.make_artefact_labels( + image, + imread(label_path), + do_multi_label=True, + threshold_artefact_brightness_percent=30, + threshold_artefact_size_percent=0, + contrast_power=30, + ) + print("detection of potential neurons (done)") + # ask the user if the artefact are not neurons + i_labels_to_add = [] + loop = True + unique_artefact = list(np.unique(artefact)) + while loop: + # visualize the artefact and ask the user which label to add to the label image + t = threading.Thread( + target=partial(ask_labels, test=test), args=(unique_artefact,) + ) + t.start() + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add), 0, artefact + ) + if viewer is None: + viewer = napari.view_image(image) + else: + viewer = viewer + viewer.add_image(image, name="image") + viewer.add_labels(artefact_copy, name="potential neurons") + viewer.add_labels(imread(label_path), name="labels") + if not test: + napari.run() + t.join() + i_labels_to_add_tmp = returns[0] + # check if the selected labels are neurones + for i in i_labels_to_add: + if i not in i_labels_to_add_tmp: + i_labels_to_add_tmp.append(i) + artefact_copy = np.where( + np.isin(artefact, i_labels_to_add_tmp), artefact, 0 + ) + print("these labels will be added") + if test: + viewer.close() + viewer = napari.view_image(image) if viewer is None else viewer + if not test: + viewer.add_labels(artefact_copy, name="labels added") + napari.run() + revert = input("Do you want to revert? (y/n)") + if test: + revert = "n" + viewer.close() + if revert != "y": + i_labels_to_add = i_labels_to_add_tmp + for i in i_labels_to_add: + if i in unique_artefact: + unique_artefact.remove(i) + if test: + break + loop = input("Do you want to add more labels? 
(y/n)") == "y" + # add the label to the label image + new_label_path = initial_label_path[:-4] + "_new_label.tif" + print("the new label will be saved in", new_label_path) + add_label(imread(label_path), artefact, new_label_path, i_labels_to_add) + # store the artefact remaining + new_artefact_path = initial_label_path[:-4] + "_artefact.tif" + artefact = np.where(np.isin(artefact, i_labels_to_add), 0, artefact) + imwrite(new_artefact_path, artefact) + + +def modify_viewer(old_label, new_label, args): + """modify the viewer to show the relabeling + Parameters + ---------- + old_label : napari.layers.Labels + the layer of the old label + new_label : napari.layers.Labels + the layer of the new label + args : list + the first element is the old label and the second element is the new label + """ + if args == "hide new label": + new_label.visible = False + elif args == "show new label": + new_label.visible = True + else: + old_label.selected_label = args[0] + if not np.isnan(args[1]): + new_label.selected_label = args[1] + + +@thread_worker +def to_show(map_labels_existing, delay=0.5): + """modify the viewer to show the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the of the map between the old label and the new label + delay : float, optional + the delay between each image for the visualization, by default 0.3 + """ + time.sleep(2) + for i in map_labels_existing: + yield "hide new label" + if len(i[1]): + yield [i[0], i[1][0]] + else: + yield [i[0], np.nan] + time.sleep(delay) + yield "show new label" + for j in i[1]: + yield [i[0], j] + time.sleep(delay) + + +def create_connected_widget( + old_label, new_label, map_labels_existing, delay=0.5 +): + """Builds a widget that can control a function in another thread.""" + + worker = to_show(map_labels_existing, delay) + worker.start() + worker.yielded.connect( + lambda arg: modify_viewer(old_label, new_label, arg) + ) + + +def visualize_map(map_labels_existing, label_path, relabel_path, delay=0.5): + """visualize the map of the relabeling + Parameters + ---------- + map_labels_existing : list + the list of the relabeling + """ + label = imread(label_path) + relabel = imread(relabel_path) + + viewer = napari.Viewer(ndisplay=3) + + old_label = viewer.add_labels(label, num_colors=3) + new_label = viewer.add_labels(relabel, num_colors=3) + old_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] + ) + new_label.colormap.colors = np.array( + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + ) + + # viewer.dims.ndisplay = 3 + viewer.camera.angles = (180, 3, 50) + viewer.camera.zoom = 1 + + old_label.show_selected_label = True + new_label.show_selected_label = True + + create_connected_widget( + old_label, new_label, map_labels_existing, delay=delay + ) + napari.run() + + +def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): + """relabel the image labelled with different label for each neuron and save it in the save_path location + Parameters + ---------- + folder_path : str + the path to the folder containing the label images + end_of_new_name : str + thename to add at the end of the relabled image + """ + for file in Path.iterdir(folder_path): + if file.suffix == ".tif": + label = imread(str(Path(folder_path / file))) + relabel_non_unique_i( + label, + str(Path(folder_path / file[:-4] + end_of_new_name + ".tif")), + ) + + +# if __name__ == "__main__": +# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") +# image_path = 
str(im_path / "image.tif") +# gt_labels_path = str(im_path / "labels.tif") +# +# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py index adfb7914..cdd02256 100644 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ b/napari_cellseg3d/dev_scripts/drafts.py @@ -1,8 +1,7 @@ import napari import numpy as np from magicgui import magicgui -from napari.types import ImageData -from napari.types import LabelsData +from napari.types import ImageData, LabelsData @magicgui(call_button="Run Threshold") diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py new file mode 100644 index 00000000..bd2f0768 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -0,0 +1,693 @@ +import napari +import numpy as np +import pandas as pd +from tqdm import tqdm + +from napari_cellseg3d.utils import LOGGER as log + +PERCENT_CORRECT = 0.5 # how much of the original label should be found by the model to be classified as correct + + +def evaluate_model_performance( + labels, + model_labels, + threshold_correct=PERCENT_CORRECT, + print_details=False, + visualize=False, +): + """Evaluate the model performance. + Parameters + ---------- + labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + print_details : bool + If True, print the results. + visualize : bool + If True, visualize the results. + Returns + ------- + neuron_found : float + The number of neurons found by the model + neuron_fused: float + The number of neurons fused by the model + neuron_not_found: float + The number of neurons not found by the model + neuron_artefact: float + The number of artefact that the model wrongly labelled as neurons + mean_true_positive_ratio_model: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the model's label) + mean_ratio_pixel_found: float + The mean (over the model's labels that correspond to one true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_ratio_pixel_found_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels)/(total number of pixels of the true label) + mean_true_positive_ratio_model_fused: float + The mean (over the model's labels that correspond to multiple true label) of (correctly labelled pixels in any fused neurons of this model's label)/(total number of pixels of the model's label) + mean_ratio_false_pixel_artefact: float + The mean (over the model's labels that are not labelled in the neurons) of (wrongly labelled pixels)/(total number of pixels of the model's label) + """ + log.debug("Mapping labels...") + map_labels_existing, map_fused_neurons, new_labels = map_labels( + labels, model_labels, threshold_correct + ) + + # calculate the number of neurons individually found + neurons_found = len(map_labels_existing) + # calculate the number of neurons fused + neurons_fused = len(map_fused_neurons) + # calculate the number of neurons not found + log.debug("Calculating the number of neurons not found...") + neurons_found_labels = np.unique( + [i[1] for i in map_labels_existing] + [i[1] for i in map_fused_neurons] + ) + unique_labels = np.unique(labels) + neurons_not_found = len(unique_labels) - 1 - len(neurons_found_labels) + # artefacts 
found + artefacts_found = len(new_labels) + if len(map_labels_existing) > 0: + # calculate the mean true positive ratio of the model + mean_true_positive_ratio_model = np.mean( + [i[3] for i in map_labels_existing] + ) + # calculate the mean ratio of the neurons pixels correctly labelled + mean_ratio_pixel_found = np.mean([i[2] for i in map_labels_existing]) + else: + mean_true_positive_ratio_model = np.nan + mean_ratio_pixel_found = np.nan + + if len(map_fused_neurons) > 0: + # calculate the mean ratio of the neurons pixels correctly labelled for the fused neurons + mean_ratio_pixel_found_fused = np.mean( + [i[2] for i in map_fused_neurons] + ) + # calculate the mean true positive ratio of the model for the fused neurons + mean_true_positive_ratio_model_fused = np.mean( + [i[3] for i in map_fused_neurons] + ) + else: + mean_ratio_pixel_found_fused = np.nan + mean_true_positive_ratio_model_fused = np.nan + + # calculate the mean false positive ratio of each artefact + if len(new_labels) > 0: + mean_ratio_false_pixel_artefact = np.mean([i[1] for i in new_labels]) + else: + mean_ratio_false_pixel_artefact = np.nan + + log.info( + f"Percent of non-fused neurons found: {neurons_found / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Percent of fused neurons found: {neurons_fused / len(unique_labels) * 100:.2f}%" + ) + log.info( + f"Overall percent of neurons found: {(neurons_found + neurons_fused) / len(unique_labels) * 100:.2f}%" + ) + + if print_details: + log.info(f"Neurons found: {neurons_found}") + log.info(f"Neurons fused: {neurons_fused}") + log.info(f"Neurons not found: {neurons_not_found}") + log.info(f"Artefacts found: {artefacts_found}") + log.info( + f"Mean true positive ratio of the model: {mean_true_positive_ratio_model}" + ) + log.info( + f"Mean ratio of the neurons pixels correctly labelled: {mean_ratio_pixel_found}" + ) + log.info( + f"Mean ratio of the neurons pixels correctly labelled for fused neurons: {mean_ratio_pixel_found_fused}" + ) + log.info( + f"Mean true positive ratio of the model for fused neurons: {mean_true_positive_ratio_model_fused}" + ) + log.info( + f"Mean ratio of the false pixels labelled as neurons: {mean_ratio_false_pixel_artefact}" + ) + + if visualize: + viewer = napari.Viewer() + viewer.add_labels(labels, name="ground truth") + viewer.add_labels(model_labels, name="model's labels") + found_model = np.where( + np.isin(model_labels, [i[0] for i in map_labels_existing]), + model_labels, + 0, + ) + viewer.add_labels(found_model, name="model's labels found") + found_label = np.where( + np.isin(labels, [i[1] for i in map_labels_existing]), labels, 0 + ) + viewer.add_labels(found_label, name="ground truth found") + neurones_not_found_labels = np.where( + np.isin(unique_labels, neurons_found_labels) is False, + unique_labels, + 0, + ) + neurones_not_found_labels = neurones_not_found_labels[ + neurones_not_found_labels != 0 + ] + not_found = np.where( + np.isin(labels, neurones_not_found_labels), labels, 0 + ) + viewer.add_labels(not_found, name="ground truth not found") + artefacts_found = np.where( + np.isin(model_labels, [i[0] for i in new_labels]), + model_labels, + 0, + ) + viewer.add_labels(artefacts_found, name="model's labels artefacts") + fused_model = np.where( + np.isin(model_labels, [i[0] for i in map_fused_neurons]), + model_labels, + 0, + ) + viewer.add_labels(fused_model, name="model's labels fused") + fused_label = np.where( + np.isin(labels, [i[1] for i in map_fused_neurons]), labels, 0 + ) + viewer.add_labels(fused_label, name="ground 
truth fused") + napari.run() + + return ( + neurons_found, + neurons_fused, + neurons_not_found, + artefacts_found, + mean_true_positive_ratio_model, + mean_ratio_pixel_found, + mean_ratio_pixel_found_fused, + mean_true_positive_ratio_model_fused, + mean_ratio_false_pixel_artefact, + ) + + +def map_labels(gt_labels, model_labels, threshold_correct=PERCENT_CORRECT): + """Map the model's labels to the neurons labels. + Parameters + ---------- + gt_labels : ndarray + Label image with neurons labelled as mulitple values. + model_labels : ndarray + Label image from the model labelled as mulitple values. + Returns + ------- + map_labels_existing: numpy array + The label value of the model and the label value of the neuron associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled + map_fused_neurons: numpy array + The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones + new_labels: list + The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact + """ + map_labels_existing = [] + map_fused_neurons = [] + new_labels = [] + + for i in tqdm(np.unique(model_labels)): + if i == 0: + continue + indexes = gt_labels[model_labels == i] + # find the most common labels in the label i of the model + unique, counts = np.unique(indexes, return_counts=True) + tmp_map = [] + total_pixel_found = 0 + + # log.debug(f"i: {i}") + for ii in range(len(unique)): + true_positive_ratio_model = counts[ii] / np.sum(counts) + # if >50% of the pixels of the label i of the model correspond to the background it is considered as an artifact, that should not have been found + # log.debug(f"unique: {unique[ii]}") + if unique[ii] == 0: + if true_positive_ratio_model > threshold_correct: + # -> artifact found + new_labels.append([i, true_positive_ratio_model]) + else: + # if >50% of the pixels of the label unique[ii] of the true label map to the same label i of the model, + # the label i is considered either as a fused neurons, if it the case for multiple unique[ii] or as neurone found + ratio_pixel_found = counts[ii] / np.sum( + gt_labels == unique[ii] + ) + if ratio_pixel_found > threshold_correct: + total_pixel_found += np.sum(counts[ii]) + tmp_map.append( + [ + i, + unique[ii], + ratio_pixel_found, + true_positive_ratio_model, + ] + ) + + if len(tmp_map) == 1: + # map to only one true neuron -> found neuron + map_labels_existing.append(tmp_map[0]) + elif len(tmp_map) > 1: + # map to multiple true neurons -> fused neuron + for ii in range(len(tmp_map)): + if total_pixel_found > np.sum(counts): + raise ValueError( + f"total_pixel_found > np.sum(counts) : {total_pixel_found} > {np.sum(counts)}" + ) + tmp_map[ii][3] = total_pixel_found / np.sum(counts) + map_fused_neurons += tmp_map + + # log.debug(f"map_labels_existing: {map_labels_existing}") + # log.debug(f"map_fused_neurons: {map_fused_neurons}") + # log.debug(f"new_labels: {new_labels}") + return map_labels_existing, map_fused_neurons, new_labels + + +def save_as_csv(results, path): + """ + Save the results as a csv file + + Parameters + ---------- + results: list + The results of the evaluation + path: str + The path to save the csv file + """ + 
log.debug(np.array(results).shape) + df = pd.DataFrame( + [results], + columns=[ + "neurons_found", + "neurons_fused", + "neurons_not_found", + "artefacts_found", + "mean_true_positive_ratio_model", + "mean_ratio_pixel_found", + "mean_ratio_pixel_found_fused", + "mean_true_positive_ratio_model_fused", + "mean_ratio_false_pixel_artefact", + ], + ) + df.to_csv(path, index=False) + + +####################### +# Slower version that was used for debugging +####################### + +# from collections import Counter +# from dataclasses import dataclass +# from typing import Dict +# @dataclass +# class LabelInfo: +# gt_index: int +# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) +# best_model_label_coverage: float = ( +# 0.0 # ratio of pixels of the gt label correctly labelled +# ) +# overall_gt_label_coverage: float = 0.0 # true positive ration of the model +# +# def get_correct_ratio(self): +# for model_label, status in self.model_labels_id_and_status.items(): +# if status == "correct": +# return self.best_model_label_coverage +# else: +# return None + + +# def eval_model(gt_labels, model_labels, print_report=False): +# +# report_list, new_labels, fused_labels = create_label_report( +# gt_labels, model_labels +# ) +# per_label_perfs = [] +# for report in report_list: +# if print_report: +# log.info( +# f"Label {report.gt_index} : {report.model_labels_id_and_status}" +# ) +# log.info( +# f"Best model label coverage : {report.best_model_label_coverage}" +# ) +# log.info( +# f"Overall gt label coverage : {report.overall_gt_label_coverage}" +# ) +# +# perf = report.get_correct_ratio() +# if perf is not None: +# per_label_perfs.append(perf) +# +# per_label_perfs = np.array(per_label_perfs) +# return per_label_perfs.mean(), new_labels, fused_labels + + +# def create_label_report(gt_labels, model_labels): +# """Map the model's labels to the neurons labels. +# Parameters +# ---------- +# gt_labels : ndarray +# Label image with neurons labelled as mulitple values. +# model_labels : ndarray +# Label image from the model labelled as mulitple values. 
+# Returns +# ------- +# map_labels_existing: numpy array +# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled +# map_fused_neurons: numpy array +# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones +# new_labels: list +# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact +# """ +# +# map_labels_existing = [] +# map_fused_neurons = {} +# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" +# background_labels = model_labels[np.where((gt_labels == 0))] +# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" +# new_labels = [] +# for lab in np.unique(background_labels): +# if lab == 0: +# continue +# gt_background_size_at_lab = ( +# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] +# .flatten() +# .shape[0] +# ) +# gt_lab_size = ( +# gt_labels[np.where(model_labels == lab)].flatten().shape[0] +# ) +# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: +# new_labels.append(lab) +# +# label_report_list = [] +# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label +# # model_label_values = {} # contains the model labels value assigned to each unique gt label +# not_found_id = 0 +# +# for i in tqdm(np.unique(gt_labels)): +# if i == 0: +# continue +# +# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label +# +# model_lab_on_gt = model_labels[ +# np.where(((gt_labels == i) & (model_labels != 0))) +# ] # all models labels on single gt_label +# info = LabelInfo(i) +# +# info.model_labels_id_and_status = { +# label_id: "" for label_id in np.unique(model_lab_on_gt) +# } +# +# if model_lab_on_gt.shape[0] == 0: +# info.model_labels_id_and_status[ +# f"not_found_{not_found_id}" +# ] = "not found" +# not_found_id += 1 +# label_report_list.append(info) +# continue +# +# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") +# +# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label +# log.debug( +# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" +# ) +# +# ratio = [] +# for model_lab_id in info.model_labels_id_and_status.keys(): +# size_model_label = ( +# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] +# .flatten() +# .shape[0] +# ) +# size_gt_label = gt_label.flatten().shape[0] +# +# log.debug(f"size_model_label : {size_model_label}") +# log.debug(f"size_gt_label : {size_gt_label}") +# +# ratio.append(size_model_label / size_gt_label) +# +# # log.debug(ratio) +# ratio_model_lab_for_given_gt_lab = np.array(ratio) +# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() +# +# best_model_lab_id = model_lab_on_gt[ +# np.argmax(ratio_model_lab_for_given_gt_lab) +# ] +# log.debug(f"best_model_lab_id : {best_model_lab_id}") +# +# info.overall_gt_label_coverage = ( +# ratio_model_lab_for_given_gt_lab.sum() +# ) # the ratio of the pixels of the true label correctly labelled +# +# if 
info.best_model_label_coverage > PERCENT_CORRECT:
+#             info.model_labels_id_and_status[best_model_lab_id] = "correct"
+#             # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0]
+#         else:
+#             info.model_labels_id_and_status[best_model_lab_id] = "wrong"
+#         for model_lab_id in np.unique(model_lab_on_gt):
+#             if model_lab_id != best_model_lab_id:
+#                 log.debug(model_lab_id, "is wrong")
+#                 info.model_labels_id_and_status[model_lab_id] = "wrong"
+#
+#         label_report_list.append(info)
+#
+#     correct_labels_id = []
+#     for report in label_report_list:
+#         for i_lab in report.model_labels_id_and_status.keys():
+#             if report.model_labels_id_and_status[i_lab] == "correct":
+#                 correct_labels_id.append(i_lab)
+#     """Find all labels in label_report_list that are correct more than once"""
+#     duplicated_labels = [
+#         item for item, count in Counter(correct_labels_id).items() if count > 1
+#     ]
+#     "Sum up the size of all duplicated labels"
+#     for i in duplicated_labels:
+#         for report in label_report_list:
+#             if (
+#                 i in report.model_labels_id_and_status.keys()
+#                 and report.model_labels_id_and_status[i] == "correct"
+#             ):
+#                 size = (
+#                     model_labels[np.where(model_labels == i)]
+#                     .flatten()
+#                     .shape[0]
+#                 )
+#                 map_fused_neurons[i] = size
+#
+#     return label_report_list, new_labels, map_fused_neurons
+
+# if __name__ == "__main__":
+#     """
+#     # Example of how to use the functions in this module.
+#     a = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+#
+#     b = np.array([[5, 5, 0, 0], [5, 5, 2, 0], [0, 2, 2, 0], [0, 0, 2, 0]])
+#     evaluate_model_performance(a, b)
+#
+#     c = np.array([[2, 2, 0, 0], [2, 2, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+#
+#     d = np.array([[4, 0, 4, 0], [4, 4, 4, 0], [0, 4, 4, 0], [0, 0, 4, 0]])
+#
+#     evaluate_model_performance(c, d)
+#
+#     from tifffile import imread
+#     labels=imread("dataset/visual_tif/labels/testing_im_new_label.tif")
+#     labels_model=imread("dataset/visual_tif/artefact_neurones/basic_model.tif")
+#     evaluate_model_performance(labels, labels_model,visualize=True)
+#     """
+#     from tifffile import imread
+#
+#     labels = imread("dataset_clean/VALIDATION/validation_labels.tif")
+#     try:
+#         labels_model = imread("results/watershed_based_model/instance_labels.tif")
+#     except:
+#         raise Exception(
+#             "you should download the model's label that are under results (output and statistics)/watershed_based_model/instance_labels.tif and put it in the folder results/watershed_based_model/"
+#         )
+#
+#     evaluate_model_performance(labels, labels_model, visualize=True)
diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py
index 998645cb..20668125 100644
--- a/napari_cellseg3d/dev_scripts/thread_test.py
+++ b/napari_cellseg3d/dev_scripts/thread_test.py
@@ -3,13 +3,15 @@
 import napari
 import numpy as np
 from napari.qt.threading import thread_worker
-from qtpy.QtWidgets import QGridLayout
-from qtpy.QtWidgets import QLabel
-from qtpy.QtWidgets import QProgressBar
-from qtpy.QtWidgets import QPushButton
-from qtpy.QtWidgets import QTextEdit
-from qtpy.QtWidgets import QVBoxLayout
-from qtpy.QtWidgets import QWidget
+from qtpy.QtWidgets import (
+    QGridLayout,
+    QLabel,
+    QProgressBar,
+    QPushButton,
+    QTextEdit,
+    QVBoxLayout,
+    QWidget,
+)
 
 
 @thread_worker
diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py
index d23199ee..276f9214 100644
--- a/napari_cellseg3d/interface.py
+++ b/napari_cellseg3d/interface.py
@@ -1,41 +1,37 @@
 import threading
 import 
warnings from functools import partial -from typing import List -from typing import Optional +from typing import List, Optional import napari # Qt -from qtpy import QtCore - # from qtpy.QtCore import QtWarningMsg -from qtpy.QtCore import QObject -from qtpy.QtCore import Qt -from qtpy.QtCore import QUrl -from qtpy.QtGui import QCursor -from qtpy.QtGui import QDesktopServices -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QCheckBox -from qtpy.QtWidgets import QComboBox -from qtpy.QtWidgets import QDoubleSpinBox -from qtpy.QtWidgets import QFileDialog -from qtpy.QtWidgets import QGridLayout -from qtpy.QtWidgets import QGroupBox -from qtpy.QtWidgets import QHBoxLayout -from qtpy.QtWidgets import QLabel -from qtpy.QtWidgets import QLayout -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QMenu -from qtpy.QtWidgets import QPushButton -from qtpy.QtWidgets import QRadioButton -from qtpy.QtWidgets import QScrollArea -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QSlider -from qtpy.QtWidgets import QSpinBox -from qtpy.QtWidgets import QTextEdit -from qtpy.QtWidgets import QVBoxLayout -from qtpy.QtWidgets import QWidget +from qtpy import QtCore +from qtpy.QtCore import QObject, Qt, QUrl +from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLayout, + QLineEdit, + QMenu, + QPushButton, + QRadioButton, + QScrollArea, + QSizePolicy, + QSlider, + QSpinBox, + QTextEdit, + QVBoxLayout, + QWidget, +) # Local from napari_cellseg3d import utils @@ -189,7 +185,7 @@ def show_utils_menu(self, widget, event): menu.setStyleSheet(f"background-color: {napari_grey}; color: white;") actions = [] - for title in UTILITIES_WIDGETS.keys(): + for title in UTILITIES_WIDGETS: a = menu.addAction(f"Utilities : {title}") actions.append(a) @@ -499,9 +495,12 @@ def __init__( self._build_container() - def _build_container(self): - self.container.layout + def set_visibility(self, visible: bool): + self.container.setVisible(visible) + self.setVisible(visible) + self.text_label.setVisible(visible) + def _build_container(self): if self.text_label is not None: add_widgets( self.container.layout, @@ -771,7 +770,7 @@ def layer_name(self): def layer_data(self): if self.layer_list.count() < 1: warnings.warn("Please select a valid layer !") - return + return None return self._viewer.layers[self.layer_name()].data @@ -1009,7 +1008,7 @@ def make_n_spinboxes( raise ValueError("Cannot make less than 2 spin boxes") boxes = [] - for i in range(n): + for _i in range(n): box = class_(min, max, default, step, parent, fixed) boxes.append(box) return boxes @@ -1021,7 +1020,7 @@ class DoubleIncrementCounter(QDoubleSpinBox): def __init__( self, lower: Optional[float] = 0.0, - upper: Optional[float] = 10.0, + upper: Optional[float] = 1000.0, default: Optional[float] = 0.0, step: Optional[float] = 1.0, parent: Optional[QWidget] = None, @@ -1045,6 +1044,13 @@ def __init__( if label is not None: self.label = make_label(name=label) + self.valueChanged.connect(self._update_step) + + def _update_step(self): # FIXME check divide_factor + if self.value() < 0.9: + self.setSingleStep(0.01) + else: + self.setSingleStep(0.1) @property def tooltips(self): @@ -1081,6 +1087,10 @@ def make_n( cls, n, lower, upper, default, step, parent, fixed ) + def set_visibility(self, visible: bool): + self.setVisible(visible) + self.label.setVisible(visible) + class 
IntIncrementCounter(QSpinBox): """Class implementing a number counter with increments (spin box) for int.""" diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 6a3b57d3..a52c3de9 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -4,8 +4,6 @@ from pathlib import Path import numpy as np -from pandas import DataFrame -from pandas import Series from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread @@ -16,7 +14,6 @@ # LOGGER.setLevel(logging.DEBUG) LOGGER.setLevel(logging.INFO) ############### - """ utils.py ==================================== @@ -35,9 +32,7 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] @@ -133,7 +128,7 @@ def resize(image, zoom_factors): mode="nearest-exact", padding_mode="empty", )(np.expand_dims(image, axis=0)) - return isotropic_image[0] + return isotropic_image[0].numpy() def align_array_sizes(array_shape, target_shape): @@ -141,9 +136,8 @@ def align_array_sizes(array_shape, target_shape): for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: for j in range(len(array_shape)): - if array_shape[i] == target_shape[j]: - if j != i: - index_differences.append({"origin": i, "target": j}) + if array_shape[i] == target_shape[j] and j != i: + index_differences.append({"origin": i, "target": j}) # print(index_differences) if len(index_differences) == 0: @@ -279,51 +273,51 @@ def annotation_to_input(label_ermito): return anno -def check_csv(project_path, ext): - if not Path(Path(project_path) / Path(project_path).name).is_file(): - cols = [ - "project", - "type", - "ext", - "z", - "y", - "x", - "z_size", - "y_size", - "x_size", - "created_date", - "update_date", - "path", - "notes", - ] - df = DataFrame(index=[], columns=cols) - filename_pattern_original = Path(project_path) / Path( - f"dataset/Original_size/Original/*{ext}" - ) - images_original = tfl_imread(filename_pattern_original) - z, y, x = images_original.shape - record = Series( - [ - Path(project_path).name, - "dataset", - ".tif", - 0, - 0, - 0, - z, - y, - x, - datetime.datetime.now(), - "", - Path(project_path) / Path("dataset/Original_size/Original"), - "", - ], - index=df.columns, - ) - df = df.append(record, ignore_index=True) - df.to_csv(Path(project_path) / Path(project_path).name) - else: - pass +# def check_csv(project_path, ext): +# if not Path(Path(project_path) / Path(project_path).name).is_file(): +# cols = [ +# "project", +# "type", +# "ext", +# "z", +# "y", +# "x", +# "z_size", +# "y_size", +# "x_size", +# "created_date", +# "update_date", +# "path", +# "notes", +# ] +# df = DataFrame(index=[], columns=cols) +# filename_pattern_original = Path(project_path) / Path( +# f"dataset/Original_size/Original/*{ext}" +# ) +# images_original = dask_imread(filename_pattern_original) +# z, y, x = images_original.shape +# record = Series( +# [ +# Path(project_path).name, +# "dataset", +# ".tif", +# 0, +# 0, +# 0, +# z, +# y, +# x, +# datetime.datetime.now(), +# "", +# Path(project_path) / Path("dataset/Original_size/Original"), +# "", +# ], +# index=df.columns, +# ) +# df = df.append(record, ignore_index=True) +# df.to_csv(Path(project_path) / Path(project_path).name) +# else: +# pass # def check_annotations_dir(project_path): @@ -358,9 +352,10 @@ def fill_list_in_between(lst, n, elem): new_list += 
temp_list else: new_list.append(lst[i]) - for j in range(n): + for _j in range(n): new_list.append(elem) return new_list + return None # def check_zarr(project_path, ext): @@ -412,17 +407,17 @@ def parse_default_path(possible_paths): def get_date_time(): """Get date and time in the following format : year_month_day_hour_minute_second""" - return "{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%Y_%m_%d_%H_%M_%S}" def get_time(): """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid)""" - return "{:%H:%M:%S}".format(datetime.now()) + return f"{datetime.now():%H:%M:%S}" def get_time_filepath(): """Get time in the following format : hour_minute_second. Compatible with saving""" - return "{:%H_%M_%S}".format(datetime.now()) + return f"{datetime.now():%H_%M_%S}" def load_images( @@ -466,6 +461,7 @@ def load_images( LOGGER.error( "Loading a stack this way is no longer supported. Use napari to load a stack." ) + else: images_original = tfl_imread( filename_pattern_original @@ -486,12 +482,12 @@ def load_images( # return base_label -def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): - images_label = load_images(mod_mask_dir, filetype, as_folder) - if as_folder: - images_label = images_label.compute() - base_label = images_label - return base_label +# def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): +# images_label = load_images(mod_mask_dir, filetype, as_folder) +# if as_folder: +# images_label = images_label.compute() +# base_label = images_label +# return base_label def save_stack(images, out_path, filetype=".png", check_warnings=False): diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb new file mode 100644 index 00000000..b8810301 --- /dev/null +++ b/notebooks/assess_instance.ipynb @@ -0,0 +1,503 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import napari\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from tifffile import imread\n", + "\n", + "from napari_cellseg3d.dev_scripts import evaluate_labels as eval\n", + "from napari_cellseg3d.utils import resize\n", + "from napari_cellseg3d.code_models.model_instance_seg import (\n", + " binary_connected,\n", + " binary_watershed,\n", + " voronoi_otsu,\n", + " to_semantic,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "viewer = napari.Viewer()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "prediction_path = str(im_path / \"pred.tif\")\n", + "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", + "\n", + "prediction = imread(prediction_path)\n", + "gt_labels = imread(gt_labels_path)\n", + "\n", + "zoom = (1 / 5, 1, 1)\n", + "prediction_resized = resize(prediction, zoom)\n", + "gt_labels_resized = resize(gt_labels, zoom)\n", + "\n", + "\n", + "viewer.add_image(prediction_resized, name=\"pred\", colormap=\"inferno\")\n", + "viewer.add_labels(gt_labels_resized, name=\"gt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + 
"jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5817600487210719" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from napari_cellseg3d.utils import dice_coeff\n", + "\n", + "dice_coeff(\n", + " to_semantic(gt_labels_resized.copy()),\n", + " to_semantic(prediction_resized.copy()),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.dev_scripts.correct_labels import relabel\n", + "\n", + "# gt_corrected = relabel(prediction_path, gt_labels_path, go_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, gt_labels_resized)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(25, 64, 64)\n", + "(25, 64, 64)\n", + "125\n" + ] + } + ], + "source": [ + "print(prediction_resized.shape)\n", + "print(gt_labels_resized.shape)\n", + "print(np.unique(gt_labels_resized).shape[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "connected = binary_connected(prediction_resized, thres_small=2)\n", + "viewer.add_labels(connected, name=\"connected\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,057 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", + "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(65,\n", + " 46,\n", + " 13,\n", + " 12,\n", + " 0.9042297461803984,\n", + " 0.8512759824829847,\n", + " 0.9136359067720888,\n", + " 0.8728146835389444,\n", + " 1.0)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, connected)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,168 - Mapping labels...\n" + ] + }, + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", + "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(68,\n", + " 43,\n", + " 13,\n", + " 10,\n", + " 0.8856947654346812,\n", + " 0.8747475859219296,\n", + " 0.9187750563205743,\n", + " 0.862012598981557,\n", + " 1.0)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "watershed = binary_watershed(\n", + " prediction_resized, thres_small=2, rem_seed_thres=1\n", + ")\n", + "viewer.add_labels(watershed)\n", + "eval.evaluate_model_performance(gt_labels_resized, watershed)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(25, 64, 64)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "\n", + "from skimage.morphology import remove_small_objects\n", + "\n", + "voronoi = remove_small_objects(voronoi, 2)\n", + "viewer.add_labels(voronoi)\n", + "voronoi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "dtype('int64')" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gt_labels_resized.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(voronoi, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# np.unique(gt_labels_resized, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,570 - Mapping labels...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", + "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", + "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", + "2023-03-22 15:48:47,610 - Overall percent of neurons 
found: 88.80%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "(99,\n", + " 12,\n", + " 13,\n", + " 17,\n", + " 0.6286692001809993,\n", + " 0.9378875115172982,\n", + " 0.949109422876503,\n", + " 0.5827007113964422,\n", + " 0.7306099091287442)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/csv_cell_plot.ipynb b/notebooks/csv_cell_plot.ipynb index 8b14fb8d..e00a9f1c 100644 --- a/notebooks/csv_cell_plot.ipynb +++ b/notebooks/csv_cell_plot.ipynb @@ -58,7 +58,6 @@ "outputs": [], "source": [ "def plot_data(data_path, x_inv=False, y_inv=False, z_inv=False):\n", - "\n", " data = pd.read_csv(data_path, index_col=False)\n", "\n", " x = data[\"Centroid x\"]\n", @@ -185,7 +184,6 @@ "outputs": [], "source": [ "def plotly_cells_stats(data):\n", - "\n", " init_notebook_mode() # initiate notebook for offline plot\n", "\n", " x = data[\"Centroid x\"]\n", diff --git a/pyproject.toml b/pyproject.toml index 11b8dced..9c5adda7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "nibabel", "scikit-image", "pillow", + "pyclesperanto-prototype", "tqdm", "matplotlib", "vispy>=0.9.6", @@ -41,8 +42,47 @@ where = ["."] "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] [tool.ruff] -# Never enforce `E501` (line length violations). 
-ignore = ["E501", "E741"] +select = [ + "E", "F", "W", + "A", + "B", + "G", + "I", + "PT", + "PTH", + "RET", + "SIM", + "TCH", + "NPY", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) +# and 'G004' (do not use f-strings in logging) +ignore = ["E501", "E741", "G004", "A003"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", +] [tool.black] line-length = 79 diff --git a/requirements.txt b/requirements.txt index f97de33c..3189e9c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ black coverage isort +itk pytest pytest-qt sphinx @@ -12,11 +13,15 @@ numpy napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 +pre-commit +pyclesperanto-prototype>=0.22.0 +pysqlite3 dask-image>=0.6.0 matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,scikit-image,einops]>=0.9.0 +monai[nibabel,einops]>=1.0.1 pillow +scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/tox.ini b/tox.ini index 292b8fa4..87338cd8 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,6 @@ deps = magicgui pytest-qt qtpy -; opencv-python +; pyopencl[pocl] commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml