From 634f5a4823f5d4d4cbb85b42d755dd1ea15931e7 Mon Sep 17 00:00:00 2001
From: Cyril Achard <94955160+C-Achard@users.noreply.github.com>
Date: Wed, 12 Jul 2023 09:26:50 +0200
Subject: [PATCH] WNet + models code refactor (#36)

* Testing instance methods Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com>
* Many fixes - Fixed monai reqs - Added custom functions for label checking - Fixed return type of voronoi_otsu and utils.resize - black
* black
* Complete instance method evaluation
* Enforce pre-commit style
* Removing dask-image
* Fixed erroneous dtype conversion
* Update test_plugin_utils.py
* Update tox.ini
* Added new pre-commit hooks
* Run full suite of pre-commit hooks
* Enforce style
* Documentation update, crop contrast fix
* Updated hooks
* Update setup.cfg
* Updated project files
* Latest pre-commit hooks
* Instance segmentation refactor + Voronoi-Otsu - Improved code for instance segmentation - Added Voronoi-Otsu labeling from pyclesperanto TODO : credits for labeling
* isort
* Fix inference
* Added labeling tools + UI tweaks - Added tools from MLCourse to evaluate labels and auto-correct them - Instance seg benchmark notebook - Tweaked utils UI to scale according to Viewer size Co-Authored-By: gityves <114951621+gityves@users.noreply.github.com>
* Added pre-commit hooks
* Update .pre-commit-config.yaml
* Update pyproject.toml
* Update pyproject.toml Ruff config
* Update .gitignore
* Version bump
* Revert "Version bump" This reverts commit 6e39971b39fb926084f3ed71d82e8c25f68f8b6f.
* Fixed wrong value in instance sliders
* Relabeling tests
* Model class refactor
* Added LR scheduler in training - Added ReduceLROnPlateau with params in training - Updated training guide - Minor UI attribute refactor - black
* Update assess_instance.ipynb
* Started adding WNet
* Specify no grad in inference
* First functional WNet inference, no CRF
* Create test_models.py
* Patch for tests action + style
* Add softNCuts basic test
* Added crf Co-Authored-By: Nevexios <72894299+nevexios@users.noreply.github.com>
* More pre-commit checks
* Functional CRF
* Fix erroneous test comment, added toggle for crf - Warn if crf not installed - Fix test
* Specify missing test deps
* Trying to fix deps on Git
* Removed master link to pydensecrf
* Use commit hash
* Removed commit hash
* Removed master
* Fixes and improvements - More CRF info - Added error handling to scheduler rate - Added ETA to training - Updated padding warning trigger size
* Fixes and channel labeling prototype
* Fixes - Fixed multi-channel instance and csv stats - Fixed rotation of inference outputs - Raised max crop size
* Update plugin_model_inference.py
* Update plugin_crop.py
* Fixed patch_func sample number mismatch
* Testing relabel tools
* Fixes in inference
* add model template + fix test + wnet loading opti - test fixes - changed crf input reqs - adapted instance seg for several channels
* Update model_WNet.py
* Update model_VNet.py
* Fixed folder creation when saving to folder
* Fix check_ready for results filewidget
* Added remapping in WNet + ruff config
* Run new hooks
* Small docs update
* Testing fix
* Fixed multithread testing (locally)
* Added proper tests for train/infer
* Slight coverage increase
* Update test_plugin_inference.py
* Set window inference to 64 for WNet
* Update instance_segmentation.py
* Moved normalization to the correct place
* Added auto-set dims for cropping
* More WNet - Added experimental .pt loading for jit models - More CRF tests - Optimized WNet by loading inference only
* Update crf test/deps for testing
* Update test_and_deploy.yml
* Trying to fix tox install of pydensecrf
* Added experimental ONNX support for inference
* Updated WNet for ONNX conversion
* Added dropout param
* Minor fixes in training
* Fix weights file extension in inference + coverage - Remove unused scripts - More tests - Fixed weights type in inference
* Run all hooks
* Fix inference testing
* Changed anisotropy calculation
* Finish rebase + bump version
* Disabled small removal in Voronoi-Otsu
* Added new docs for instance seg
* Docs + UI update - Updated welcome/README - Changed step for DoubleCounter
* Update requirements.txt Fix typo
* Fix tests
* Fixed parental issues and instance seg widget init - Fixed widgets parents that were incorrectly init - Improve use of instance seg. method classes and init
* Fixed missing parent error
* Temporary test action patch
* Update plugin_convert.py
* Update tox.ini Added pocl for testing on GH Actions
* Found existing pocl
* Updated utils test to avoid Voronoi-Otsu VO is missing CL runtime
* Fixed aniso correction and CRF interaction
* Remove duplicate tests
* Finish rebase + changed step to auto in spinbox
* Updated based on feedback from CYHSM Co-Authored-By: Markus Frey <5563464+CYHSM@users.noreply.github.com>
* Added minimal WNet notebook for training
* Remove dask
* WNet model docs
* Added QoL shape info for layer selector
* WNet fixes + PR feedback improvements
* Added imagecodecs to open external datasets

---------

Co-authored-by: gityves <114951621+gityves@users.noreply.github.com>
Co-authored-by: Nevexios
<72894299+nevexios@users.noreply.github.com> Co-authored-by: Markus Frey <5563464+CYHSM@users.noreply.github.com> --- .coveragerc | 7 + .github/workflows/test_and_deploy.yml | 5 +- .gitignore | 2 + docs/index.rst | 4 +- docs/res/code/instance_segmentation.rst | 53 + docs/res/code/model_instance_seg.rst | 53 - docs/res/code/plugin_convert.rst | 15 - docs/res/code/plugin_model_training.rst | 1 - docs/res/code/utils.rst | 4 - .../code/{model_workers.rst => workers.rst} | 8 +- docs/res/guides/custom_model_template.rst | 46 +- docs/res/guides/detailed_walkthrough.rst | 4 +- docs/res/guides/inference_module_guide.rst | 37 +- docs/res/guides/training_module_guide.rst | 4 +- docs/res/guides/training_wnet.rst | 36 + docs/res/welcome.rst | 2 +- napari_cellseg3d/__init__.py | 2 +- napari_cellseg3d/_tests/fixtures.py | 9 +- napari_cellseg3d/_tests/test_dock_widget.py | 1 + napari_cellseg3d/_tests/test_interface.py | 8 +- .../_tests/test_labels_correction.py | 8 +- napari_cellseg3d/_tests/test_models.py | 97 ++ .../_tests/test_plugin_inference.py | 46 +- napari_cellseg3d/_tests/test_plugin_utils.py | 16 +- napari_cellseg3d/_tests/test_plugins.py | 21 + napari_cellseg3d/_tests/test_training.py | 31 +- napari_cellseg3d/_tests/test_utils.py | 50 +- .../_tests/test_weight_download.py | 6 +- napari_cellseg3d/code_models/crf.py | 231 ++++ ...stance_seg.py => instance_segmentation.py} | 112 +- .../code_models/model_framework.py | 66 +- .../code_models/models/TEMPLATE_model.py | 20 + .../code_models/models/model_SegResNet.py | 48 +- .../code_models/models/model_SwinUNetR.py | 55 +- .../code_models/models/model_TRAILMAP.py | 54 +- .../code_models/models/model_TRAILMAP_MS.py | 38 +- .../code_models/models/model_VNet.py | 37 +- .../code_models/models/model_WNet.py | 42 + .../code_models/models/model_test.py | 40 +- .../pretrained/pretrained_model_urls.json | 1 + .../code_models/models/unet/buildingblocks.py | 3 +- .../code_models/models/wnet/__init__.py | 0 .../code_models/models/wnet/model.py | 240 ++++ .../code_models/models/wnet/soft_Ncuts.py | 225 ++++ .../code_models/models/wnet/train_wnet.py | 1008 +++++++++++++++++ .../{model_workers.py => workers.py} | 649 +++++++---- napari_cellseg3d/code_plugins/plugin_base.py | 43 +- .../code_plugins/plugin_convert.py | 121 +- napari_cellseg3d/code_plugins/plugin_crf.py | 290 +++++ napari_cellseg3d/code_plugins/plugin_crop.py | 73 +- .../code_plugins/plugin_helper.py | 6 +- .../code_plugins/plugin_metrics.py | 14 +- .../code_plugins/plugin_model_inference.py | 362 +++--- .../code_plugins/plugin_model_training.py | 206 ++-- .../code_plugins/plugin_review.py | 8 +- .../code_plugins/plugin_review_dock.py | 10 +- .../code_plugins/plugin_utilities.py | 22 +- napari_cellseg3d/config.py | 115 +- .../dev_scripts/artefact_labeling.py | 23 +- napari_cellseg3d/dev_scripts/convert.py | 26 - .../dev_scripts/correct_labels.py | 16 +- napari_cellseg3d/dev_scripts/drafts.py | 15 - .../dev_scripts/evaluate_labels.py | 189 +--- napari_cellseg3d/dev_scripts/thread_test.py | 6 +- napari_cellseg3d/dev_scripts/view_brain.py | 8 - napari_cellseg3d/dev_scripts/view_sample.py | 29 - .../dev_scripts/weight_conversion.py | 234 ---- napari_cellseg3d/interface.py | 253 +++-- napari_cellseg3d/utils.py | 231 +++- notebooks/assess_instance.ipynb | 125 +- notebooks/full_plot.ipynb | 1 - notebooks/train_wnet.ipynb | 267 +++++ pyproject.toml | 23 +- requirements.txt | 7 +- setup.cfg | 60 + tox.ini | 7 +- 76 files changed, 4502 insertions(+), 1733 deletions(-) create mode 100644 .coveragerc create mode 
100644 docs/res/code/instance_segmentation.rst delete mode 100644 docs/res/code/model_instance_seg.rst rename docs/res/code/{model_workers.rst => workers.rst} (78%) create mode 100644 docs/res/guides/training_wnet.rst create mode 100644 napari_cellseg3d/_tests/test_models.py create mode 100644 napari_cellseg3d/_tests/test_plugins.py create mode 100644 napari_cellseg3d/code_models/crf.py rename napari_cellseg3d/code_models/{model_instance_seg.py => instance_segmentation.py} (85%) create mode 100644 napari_cellseg3d/code_models/models/TEMPLATE_model.py create mode 100644 napari_cellseg3d/code_models/models/model_WNet.py create mode 100644 napari_cellseg3d/code_models/models/wnet/__init__.py create mode 100644 napari_cellseg3d/code_models/models/wnet/model.py create mode 100644 napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py create mode 100644 napari_cellseg3d/code_models/models/wnet/train_wnet.py rename napari_cellseg3d/code_models/{model_workers.py => workers.py} (72%) create mode 100644 napari_cellseg3d/code_plugins/plugin_crf.py delete mode 100644 napari_cellseg3d/dev_scripts/convert.py delete mode 100644 napari_cellseg3d/dev_scripts/drafts.py delete mode 100644 napari_cellseg3d/dev_scripts/view_brain.py delete mode 100644 napari_cellseg3d/dev_scripts/view_sample.py delete mode 100644 napari_cellseg3d/dev_scripts/weight_conversion.py create mode 100644 notebooks/train_wnet.ipynb diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..038f3d5a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +exclude_lines = + if __name__ == .__main__.: + +[run] +omit = + napari_cellseg3d/setup.py diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ea0a1e46..406bf4f5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -7,15 +7,11 @@ on: push: branches: - main - - npe2 - - cy/voronoi-otsu tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: branches: - main - - npe2 - - cy/voronoi-otsu workflow_dispatch: jobs: @@ -55,6 +51,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions +# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf # this runs the platform-specific tests declared in tox.ini - name: Test with tox diff --git a/.gitignore b/.gitignore index df43b4fa..7460d861 100644 --- a/.gitignore +++ b/.gitignore @@ -104,9 +104,11 @@ notebooks/csv_cell_plot.html notebooks/full_plot.html *.csv *.png +notebooks/instance_test.ipynb *.prof #include test data !napari_cellseg3d/_tests/res/test.tif !napari_cellseg3d/_tests/res/test.png !napari_cellseg3d/_tests/res/test_labels.tif +cov.syspath.txt diff --git a/docs/index.rst b/docs/index.rst index 7e809fbe..46c57c08 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,8 +39,8 @@ Welcome to napari-cellseg3d's documentation! 
res/code/plugin_convert res/code/plugin_metrics res/code/model_framework - res/code/model_workers - res/code/model_instance_seg + res/code/workers + res/code/instance_segmentation res/code/plugin_model_inference res/code/plugin_model_training res/code/utils diff --git a/docs/res/code/instance_segmentation.rst b/docs/res/code/instance_segmentation.rst new file mode 100644 index 00000000..143560c4 --- /dev/null +++ b/docs/res/code/instance_segmentation.rst @@ -0,0 +1,53 @@ +instance_segmentation.py +=========================================== + +Classes +------------- + +InstanceMethod +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::InstanceMethod + :members: __init__ + +ConnectedComponents +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::ConnectedComponents + :members: __init__ + +Watershed +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::Watershed + :members: __init__ + +VoronoiOtsu +************************************** +.. autoclass:: napari_cellseg3d.code_models.instance_segmentation::VoronoiOtsu + :members: __init__ + + +Functions +------------- + +binary_connected +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_connected + +binary_watershed +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::binary_watershed + +volume_stats +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::volume_stats + +clear_small_objects +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::clear_small_objects + +to_instance +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_instance + +to_semantic +************************************** +.. autofunction:: napari_cellseg3d.code_models.instance_segmentation::to_semantic diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst deleted file mode 100644 index 3b323173..00000000 --- a/docs/res/code/model_instance_seg.rst +++ /dev/null @@ -1,53 +0,0 @@ -model_instance_seg.py -=========================================== - -Classes -------------- - -InstanceMethod -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::InstanceMethod - :members: __init__ - -ConnectedComponents -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::ConnectedComponents - :members: __init__ - -Watershed -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::Watershed - :members: __init__ - -VoronoiOtsu -************************************** -.. autoclass:: napari_cellseg3d.code_models.model_instance_seg::VoronoiOtsu - :members: __init__ - - -Functions -------------- - -binary_connected -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_connected - -binary_watershed -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_watershed - -volume_stats -************************************** -.. 
autofunction:: napari_cellseg3d.code_models.model_instance_seg::volume_stats - -clear_small_objects -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::clear_small_objects - -to_instance -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_instance - -to_semantic -************************************** -.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_semantic diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index 03944510..25006d0f 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -28,18 +28,3 @@ ThresholdUtils ********************************** .. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils :members: __init__ - -Functions ------------------------------------ - -save_folder -***************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder - -save_layer -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer - -show_result -**************************************** -.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index 870dfd14..dc1271fc 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -18,6 +18,5 @@ Methods Attributes ********************* - .. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index e90ee7e0..d9fdcfa2 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -62,7 +62,3 @@ denormalize_y load_images ************************************** .. autofunction:: napari_cellseg3d.utils::load_images - -format_Warning -************************************** -.. autofunction:: napari_cellseg3d.utils::format_Warning diff --git a/docs/res/code/model_workers.rst b/docs/res/code/workers.rst similarity index 78% rename from docs/res/code/model_workers.rst rename to docs/res/code/workers.rst index 85f8da29..1f5167ad 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/workers.rst @@ -1,4 +1,4 @@ -model_workers.py +workers.py =========================================== @@ -10,7 +10,7 @@ Class : LogSignal Attributes ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::LogSignal +.. autoclass:: napari_cellseg3d.code_models.workers::LogSignal :members: log_signal :noindex: @@ -24,7 +24,7 @@ Class : InferenceWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::InferenceWorker +.. autoclass:: napari_cellseg3d.code_models.workers::InferenceWorker :members: __init__, log, create_inference_dict, inference :noindex: @@ -39,6 +39,6 @@ Class : TrainingWorker Methods ************************ -.. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker +.. 
autoclass:: napari_cellseg3d.code_models.workers::TrainingWorker
    :members: __init__, log, train
    :noindex:
diff --git a/docs/res/guides/custom_model_template.rst b/docs/res/guides/custom_model_template.rst
index afbcd98a..b7eb65e3 100644
--- a/docs/res/guides/custom_model_template.rst
+++ b/docs/res/guides/custom_model_template.rst
@@ -3,35 +3,33 @@
 Advanced : Declaring a custom model
 =============================================

-To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:
+.. warning::
+    **WIP** : Adding new models is still a work in progress and will likely not work out of the box, leading to errors.

-.. note::
-    **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute
-
-::
-
-    def get_net():
-        return ModelClass # should return the class of the model,
-        # for example SegResNet or UNET
+    Please `file an issue`_ if you would like to add a custom model and we will help you get it working.

+To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder::

-    def get_weights_file():
-        return "weights_file.pth" # name of the weights file for the model,
-        # which should be in *napari_cellseg3d/models/pretrained*
+    class ModelTemplate_(ABC):  # replace ABC with your PyTorch model class name
+        use_default_training = True  # not needed for now, will serve for WNet training if added to the plugin
+        weights_file = (
+            "model_template.pth"  # specify the file name of the weights file only
+        )  # download URL goes in pretrained_models.json

+        @abstractmethod
+        def __init__(
+            self, input_image_size, in_channels=1, out_channels=1, **kwargs
+        ):
+            """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported."""
+            pass

-    def get_output(model, input):
-        out = model(input) # should return the model's output as [C, N, D,H,W]
-        # (C: channel, N, batch size, D,H,W : depth, height, width)
-        return out
+        @abstractmethod
+        def forward(self, x):
+            """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x)."""
+            pass

-    def get_validation(model, val_inputs):
-        val_outputs = model(val_inputs) # should return the proper type for validation
-        # with sliding_window_inference from MONAI
-        return val_outputs
-
+.. note::
+    **WIP** : Currently you must modify :ref:`model_framework.py` as well : import your model class and add it to the ``model_dict`` attribute

-    def ModelClass(x1,x2...):
-        # your Pytorch model here...
-        return results # should return as [C, N, D,H,W]
+.. _file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues
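+
+For illustration only, a hypothetical minimal model following this template might look as follows (``TinyConvNet3D`` and its weights file name are placeholders, not a model shipped with the plugin)::
+
+    from torch import nn
+
+    class TinyConvNet3D(nn.Module):  # hypothetical example model
+        use_default_training = True
+        weights_file = "tiny_conv_net.pth"  # placeholder file name
+
+        def __init__(self, input_image_size, in_channels=1, out_channels=1, **kwargs):
+            super().__init__()
+            # a deliberately tiny 3D conv stack; a real model would be deeper
+            self.net = nn.Sequential(
+                nn.Conv3d(in_channels, 8, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.Conv3d(8, out_channels, kernel_size=3, padding=1),
+            )
+
+        def forward(self, x):
+            # input and output both keep dims (batch, channels, z, y, x)
+            return self.net(x)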
diff --git a/docs/res/guides/detailed_walkthrough.rst b/docs/res/guides/detailed_walkthrough.rst
index 407893c2..3d06d998 100644
--- a/docs/res/guides/detailed_walkthrough.rst
+++ b/docs/res/guides/detailed_walkthrough.rst
@@ -1,6 +1,6 @@
 .. _detailed_walkthrough:

-Detailed walkthrough
+Detailed walkthrough - Supervised learning
 ===================================

 The following guide will show you how to use the plugin's workflow, starting from human-labeled annotation volume, to running inference on novel volumes.
@@ -109,7 +109,7 @@ of two no matter the size you choose.
 For optimal performance, make sure to use a power of two still, such as 64 or 120.

 .. important::
-    Using a too large value for the size will cause memory issues. If this happens, restart napari (better handling for these situations might be added in the future).
+    Using too large a value for the size will cause memory issues. If this happens, restart the worker with smaller volumes.

 You also have the option to use data augmentation, which can improve performance and generalization.
 In most cases this should be left enabled.
diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst
index 00e67078..560282ce 100644
--- a/docs/res/guides/inference_module_guide.rst
+++ b/docs/res/guides/inference_module_guide.rst
@@ -7,8 +7,9 @@ This module allows you to use pre-trained segmentation algorithms (written in Py
 to automatically label cells.

 .. important::
-    Currently, only inference on **3D volumes is supported**. Your image and label folders should both contain a set of
-    **3D image files**, currently either **.tif** or **.tiff**.
+    Currently, only inference on **3D volumes is supported**. If using folders, your images and labels folders
+    should both contain a set of **3D image files**, either **.tif** or **.tiff**.
+    Otherwise you may run inference on layers in napari.

 Currently, the following pre-trained models are available :

@@ -20,6 +21,7 @@ SegResNet `3D MRI brain tumor segmentation using autoencoder regularizati
 TRAILMAP_MS    A PyTorch implementation of the `TRAILMAP project on GitHub`_ pretrained with mesoSPIM data
 TRAILMAP       An implementation of the `TRAILMAP project on GitHub`_ using a `3DUNet for PyTorch`_
 SwinUNetR      `Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images`_
+WNet           `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_
 ==============    ================================================================================================

 .. _Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation: https://arxiv.org/pdf/1606.04797.pdf
 .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP
 .. _3DUnet for Pytorch: https://github.com/wolny/pytorch-3dunet
 .. _Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images: https://arxiv.org/abs/2201.01266
+.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506
+
+.. note::
+    For WNet-specific instructions please refer to the appropriate section below.

 Interface and functionalities
 --------------------------------
@@ -67,8 +73,7 @@ Interface and functionalities

 * **Instance segmentation** :

-    | You can convert the semantic segmentation into instance labels by using either the `watershed`_ or `connected components`_ method.
-    | You can set the probability threshold from which a pixel is considered as a valid instance, as well as the minimum size in pixels for objects. All smaller objects will be removed.
+    | You can convert the semantic segmentation into instance labels by using either the Voronoi-Otsu, `Watershed`_ or `Connected Components`_ method, as detailed in :ref:`utils_module_guide`.
     | Instance labels will be saved (and shown if applicable) separately from other results.

@@ -78,7 +83,7 @@ Interface and functionalities

 * **Computing objects statistics** :

-    You can choose to compute various stats from the labels and save them to a csv for later use.
+    You can choose to compute various stats from the labels and save them to a .csv for later use.
 This includes, for each object :
@@ -98,13 +103,6 @@ Interface and functionalities

 In the ``notebooks`` folder you can find an example of plotting cell statistics using the result csv.

-* **Viewing results** :
-
-    | You can also select whether you'd like to **see the results** in napari afterwards.
-    | By default the first image processed will be displayed, but you can choose to display up to **ten at once**.
-    | You can also request to see the originals.
-
-
 When you are done choosing your parameters, you can press the **Start** button to begin the inference process.
 Once it has finished, results will be saved then displayed in napari; each output will be paired with its original.
 On the left side, a progress bar and a log will keep you informed on the process.
@@ -115,7 +113,7 @@ On the left side, a progress bar and a log will keep you informed on the process
     | ``{original_name}_{model}_{date & time}_pred{id}.file_ext``
     | For example, using a VNet on the third image of a folder, called "somatomotor.tif" will yield the following name :
     | *somatomotor_VNet_2022_04_06_15_49_42_pred3.tif*
-    | Instance labels will have the "Instance_seg" prefix appened to the name.
+    | Instance labels will have the "Instance_seg" prefix appended to the name.

 .. hint::
@@ -128,8 +126,19 @@ On the left side, a progress bar and a log will keep you informed on the process
 .. note::
     You can save the log after the worker is finished to easily remember which parameters you ran inference with.

+WNet
+--------------------------------
+
+The WNet model, from the paper `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_, is a fully unsupervised model that can be used to segment images without any labels.
+It clusters pixels based on brightness, and can be used to segment cells in a variety of modalities.
+Its use and available options are similar to those of the models above, with a few differences :
+
+.. note::
+    | Our provided pre-trained model expects an input size of 64x64x64. As such, window inference is always enabled
+    | and set to 64. If you want to use a different size, you will have to train your own model using the provided notebook.
+
+All it requires are images; for nucleus segmentation, it is recommended to use 2 classes (default).
+
 Source code
 --------------------------------
 * :doc:`../code/plugin_model_inference`
 * :doc:`../code/model_framework`
-* :doc:`../code/model_workers`
+* :doc:`../code/workers`
diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst
index fb8992d2..1038dc6d 100644
--- a/docs/res/guides/training_module_guide.rst
+++ b/docs/res/guides/training_module_guide.rst
@@ -74,6 +74,8 @@ The training module is comprised of several tabs.
 * The **batch size** (larger means quicker training and possibly better performance but increased memory usage)
 * The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.)
 * The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.)
+* The **scheduler patience**, which is the number of epochs to wait at a plateau before the learning rate is reduced (see the sketch below)
+* The **scheduler factor**, which is the factor by which to reduce the learning rate once a plateau is reached
 * Whether to use deterministic training, and the seed to use.
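+
+As an illustrative sketch, these two settings correspond to the ``patience`` and ``factor`` arguments of PyTorch's ``ReduceLROnPlateau`` (values below are made up, not the plugin's defaults; ``optimizer`` and ``validation_loss`` are assumed to come from your training loop)::
+
+    from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+    scheduler = ReduceLROnPlateau(
+        optimizer,      # the training optimizer
+        mode="min",     # monitor a loss that should decrease
+        factor=0.5,     # scheduler factor: LR is multiplied by this at a plateau
+        patience=10,    # scheduler patience: epochs to wait before reducing the LR
+    )
+    # after each validation epoch:
+    scheduler.step(validation_loss)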

 .. note::
@@ -126,4 +128,4 @@ Source code
 --------------------------------
 * :doc:`../code/plugin_model_training`
 * :doc:`../code/model_framework`
-* :doc:`../code/model_workers`
+* :doc:`../code/workers`
diff --git a/docs/res/guides/training_wnet.rst b/docs/res/guides/training_wnet.rst
new file mode 100644
index 00000000..ecd20542
--- /dev/null
+++ b/docs/res/guides/training_wnet.rst
@@ -0,0 +1,36 @@
+.. _training_wnet:
+
+WNet model training
+===================
+
+This plugin provides a reimplemented, custom version of the WNet model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_.
+In order to train your own model, you may use the provided Jupyter notebook; support for in-plugin training might be added in the future.
+
+The WNet uses brightness to cluster objects versus background; to get the most out of the model, please use image regions with minimal
+artifacts. If your images contain many artifacts, you may then train one of the supervised models to achieve more resilient segmentation.
+
+The WNet should not require a very large amount of data to train, but during inference images should be similar to those
+the model was trained on; you can retrain starting from our pretrained model to quickly reach good performance on your own set of images.
+
+The model has two losses: the SoftNCut loss, which clusters pixels according to brightness, and a reconstruction loss, either
+Mean Square Error (MSE) or Binary Cross Entropy (BCE).
+Unlike in the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once.
+The SoftNCut loss is bounded between 0 and 1; the MSE may take large values.
+
+For good performance, one should wait for the SoftNCut to reach a plateau; the reconstruction loss should also decrease, though it is generally less critical.
+
+
+Common issues troubleshooting
+------------------------------
+If you do not find a satisfactory answer here, please `open an issue`_ !
+
+- **The NCuts loss explodes after a few epochs** : Lower the learning rate
+
+- **The NCuts loss does not converge and is unstable** :
+  The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image; for reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1.
+
+- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss to bring it closer to 1.
+
+
+.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506
+.. _open an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues
diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst
index 892549a8..8625e01f 100644
--- a/docs/res/welcome.rst
+++ b/docs/res/welcome.rst
@@ -82,7 +82,7 @@ Then go into Plugins > napari-cellseg3d, and choose which tool to use:
 - **Utilities**: This module allows you to use several utilities, e.g. to crop your volumes and labels, compute prediction scores or convert labels
 - **Help/About...** : Quick access to version info, Github page and docs

-See the documentation for links to detailed guides regarding the usage of the modules.
+See above for links to detailed guides regarding the usage of the modules.
Acknowledgments & References --------------------------------------------- diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index 11e8de0e..be8123e4 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.2rc6" +__version__ = "0.0.3rc1" diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py index b40a77d3..b3044799 100644 --- a/napari_cellseg3d/_tests/fixtures.py +++ b/napari_cellseg3d/_tests/fixtures.py @@ -1,7 +1,7 @@ -import warnings - from qtpy.QtWidgets import QTextEdit +from napari_cellseg3d.utils import LOGGER as logger + class LogFixture(QTextEdit): """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing""" @@ -13,4 +13,7 @@ def print_and_log(self, text, printing=None): print(text) def warn(self, warning): - warnings.warn(warning) + logger.warning(warning) + + def error(self, e): + raise (e) diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index 7737e540..8063c92b 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -11,6 +11,7 @@ def test_prepare(make_napari_viewer): viewer = make_napari_viewer() viewer.add_image(image) widget = Datamanager(viewer) + viewer.window.add_dock_widget(widget) widget.prepare(path_image, ".tif", "", False) diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py index be811721..08e0e675 100644 --- a/napari_cellseg3d/_tests/test_interface.py +++ b/napari_cellseg3d/_tests/test_interface.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.interface import Log +from napari_cellseg3d.interface import AnisotropyWidgets, Log def test_log(qtbot): @@ -12,3 +12,9 @@ def test_log(qtbot): assert log.toPlainText() == "\ntest2" qtbot.add_widget(log) + + +def test_zoom_factor(): + resolution = [10.0, 10.0, 5.0] + zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) + assert zoom == [1, 1, 0.5] diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py index c65d7402..b4f13238 100644 --- a/napari_cellseg3d/_tests/test_labels_correction.py +++ b/napari_cellseg3d/_tests/test_labels_correction.py @@ -37,16 +37,16 @@ def test_correct_labels(): ) -def test_relabel(make_napari_viewer): - viewer = make_napari_viewer() +def test_relabel(): cl.relabel( str(image_path), str(labels_path), go_fast=True, - viewer=viewer, test=True, ) def test_evaluate_model_performance(): - el.evaluate_model_performance(labels, labels, print_details=True) + el.evaluate_model_performance( + labels, labels, print_details=True, visualize=False + ) diff --git a/napari_cellseg3d/_tests/test_models.py b/napari_cellseg3d/_tests/test_models.py new file mode 100644 index 00000000..ebb3a614 --- /dev/null +++ b/napari_cellseg3d/_tests/test_models.py @@ -0,0 +1,97 @@ +import numpy as np +import torch +from numpy.random import PCG64, Generator + +from napari_cellseg3d.code_models.crf import ( + CRFWorker, + correct_shape_for_crf, + crf_batch, + crf_with_config, +) +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss +from napari_cellseg3d.config import MODEL_LIST, CRFConfig + +rand_gen = Generator(PCG64(12345)) + + +def test_correct_shape_for_crf(): + test = rand_gen.random(size=(1, 1, 8, 8, 8)) + assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) + test = rand_gen.random(size=(8, 8, 8)) + assert correct_shape_for_crf(test).shape == 
(1, 8, 8, 8) + + +def test_model_list(): + for model_name in MODEL_LIST: + # if model_name=="test": + # continue + dims = 128 + test = MODEL_LIST[model_name]( + input_img_size=[dims, dims, dims], + in_channels=1, + out_channels=1, + dropout_prob=0.3, + ) + assert isinstance(test, MODEL_LIST[model_name]) + + +def test_soft_ncuts_loss(): + dims = 8 + labels = torch.rand([1, 1, dims, dims, dims]) + + loss = SoftNCutsLoss( + data_shape=[dims, dims, dims], + device="cpu", + intensity_sigma=4, + spatial_sigma=4, + radius=2, + ) + + res = loss.forward(labels, labels) + assert isinstance(res, torch.Tensor) + assert 0 <= res <= 1 # ASSUMES NUMBER OF CLASS IS 2, NOT CORRECT IF K>2 + + +def test_crf_batch(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_batch( + np.array([mock_image, mock_image, mock_image]), + np.array([mock_label, mock_label, mock_label]), + sa=config.sa, + sb=config.sb, + sg=config.sg, + w1=config.w1, + w2=config.w2, + ) + + assert result.shape == (3, 2, dims, dims, dims) + + +def test_crf_config(): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + config = CRFConfig() + + result = crf_with_config(mock_image, mock_label, config) + assert result.shape == mock_label.shape + + +def test_crf_worker(qtbot): + dims = 8 + mock_image = rand_gen.random(size=(1, dims, dims, dims)) + mock_label = rand_gen.random(size=(2, dims, dims, dims)) + assert len(mock_label.shape) == 4 + crf = CRFWorker([mock_image], [mock_label]) + + def on_yield(result): + assert len(result.shape) == 4 + assert len(mock_label.shape) == 4 + assert result.shape[-3:] == mock_label.shape[-3:] + + result = next(crf._run_crf_job()) + on_yield(result) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 212c4120..0258f243 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -3,8 +3,14 @@ from tifffile import imread from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_models.instance_segmentation import ( + INSTANCE_SEGMENTATION_METHOD_LIST, +) from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_inference import ( + InferenceResult, + Inferer, +) from napari_cellseg3d.config import MODEL_LIST @@ -28,14 +34,36 @@ def test_inference(make_napari_viewer, qtbot): assert widget.check_ready() - MODEL_LIST["test"] = TestModel - widget.model_choice.addItem("test") - widget.setCurrentIndex(-1) + widget.model_choice.setCurrentText("WNet") + widget._restrict_window_size_for_model() + assert widget.window_infer_box.isChecked() + assert widget.window_size_choice.currentText() == "64" - # widget.start() # takes too long on Github Actions - # assert widget.worker is not None + test_model_name = "test" + MODEL_LIST[test_model_name] = TestModel + widget.model_choice.addItem(test_model_name) + widget.model_choice.setCurrentText(test_model_name) - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: - # blocker.connect(widget.worker.errored) + widget.worker_config = widget._set_worker_config() + assert widget.worker_config is not None + assert widget.model_info is not None + 
widget.window_infer_box.setChecked(False) + worker = widget._create_worker_from_config(widget.worker_config) - # assert len(viewer.layers) == 2 + assert worker.config is not None + assert worker.config.model_info is not None + worker.config.layer = viewer.layers[0].data + worker.config.post_process_config.instance.enabled = True + worker.config.post_process_config.instance.method = ( + INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() + ) + + assert worker.config.layer is not None + worker.log_parameters() + + res = next(worker.inference()) + assert isinstance(res, InferenceResult) + assert res.result.shape == (8, 8, 8) + assert res.instance_labels.shape == (8, 8, 8) + + widget.on_yield(res) diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 0abcf387..60c25ccc 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -1,24 +1,24 @@ -from pathlib import Path - import numpy as np -from tifffile import imread +from numpy.random import PCG64, Generator from napari_cellseg3d.code_plugins.plugin_utilities import ( UTILITIES_WIDGETS, Utilities, ) +rand_gen = Generator(PCG64(12345)) + def test_utils_plugin(make_napari_viewer): view = make_napari_viewer() widget = Utilities(view) - im_path = str(Path(__file__).resolve().parent / "res/test.tif") - image = imread(im_path) - view.add_image(image) - view.add_labels(image.astype(np.uint8)) + image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image_layer = view.add_image(image, name="image") + label_layer = view.add_labels(image.astype(np.uint8), name="labels") view.window.add_dock_widget(widget) + view.dims.ndisplay = 3 for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): widget.utils_choice.setCurrentIndex(i) assert isinstance( @@ -29,4 +29,6 @@ def test_utils_plugin(make_napari_viewer): menu = widget.utils_widgets[i].instance_widgets.method_choice menu.setCurrentIndex(menu.currentIndex() + 1) + assert len(image_layer.data.shape) == 3 + assert len(label_layer.data.shape) == 3 widget.utils_widgets[i]._start() diff --git a/napari_cellseg3d/_tests/test_plugins.py b/napari_cellseg3d/_tests/test_plugins.py new file mode 100644 index 00000000..c58d26af --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugins.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from napari_cellseg3d import plugins +from napari_cellseg3d.code_plugins import plugin_metrics as m + + +def test_all_plugins_import(make_napari_viewer): + plugins.napari_experimental_provide_dock_widget() + + +def test_plugin_metrics(make_napari_viewer): + viewer = make_napari_viewer() + w = m.MetricsUtils(viewer=viewer, parent=None) + viewer.window.add_dock_widget(w) + + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + labels_path = im_path + + w.image_filewidget.text_field = im_path + w.labels_filewidget.text_field = labels_path + w.compute_dice() diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 21731ba1..e7f1e07b 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -3,7 +3,10 @@ from napari_cellseg3d import config from napari_cellseg3d._tests.fixtures import LogFixture from napari_cellseg3d.code_models.models.model_test import TestModel -from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_model_training import ( + Trainer, + TrainingReport, +) from napari_cellseg3d.config import MODEL_LIST @@ -32,15 
+35,31 @@ def test_training(make_napari_viewer, qtbot): ################# # Training is too long to test properly this way. Do not use on Github ################# - MODEL_LIST["test"] = TestModel() + MODEL_LIST["test"] = TestModel widget.model_choice.addItem("test") widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - # widget.start() - # assert widget.worker is not None - - # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. + worker_config = widget._set_worker_config() + worker = widget._create_worker_from_config(worker_config) + worker.config.train_data_dict = [{"image": im_path, "label": im_path}] + worker.config.val_data_dict = [{"image": im_path, "label": im_path}] + worker.config.max_epochs = 1 + worker.log_parameters() + res = next(worker.train()) + + assert isinstance(res, TrainingReport) + + # def on_error(e): + # print(e) + # assert False + # + # with qtbot.waitSignal( + # signal=widget.worker.finished, timeout=10000, raising=True + # ) as blocker: # blocker.connect(widget.worker.errored) + # widget.worker.error_signal.connect(on_error) + # widget.worker.train() + # assert widget.worker is not None def test_update_loss_plot(make_napari_viewer): diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index dc57b940..dc680b35 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -1,15 +1,15 @@ -import os -import warnings +from functools import partial +from pathlib import Path import numpy as np -import pytest import torch from napari_cellseg3d import utils +from napari_cellseg3d.dev_scripts import thread_test def test_fill_list_in_between(): - list = [1, 2, 3, 4, 5, 6] + test_list = [1, 2, 3, 4, 5, 6] res = [ 1, "", @@ -31,7 +31,11 @@ def test_fill_list_in_between(): "", ] - assert utils.fill_list_in_between(list, 2, "") == res + assert utils.fill_list_in_between(test_list, 2, "") == res + + fill = partial(utils.fill_list_in_between, n=2, fill_value="") + + assert fill(test_list) == res def test_align_array_sizes(): @@ -79,15 +83,15 @@ def test_get_padding_dim(): tensor = torch.randn(2000, 30, 40) size = tensor.size() - warn = warnings.warn( - "Warning : a very large dimension for automatic padding has been computed.\n" - "Ensure your images are of an appropriate size and/or that you have enough memory." - "The padding value is currently 2048." - ) - + # warn = logger.warning( + # "Warning : a very large dimension for automatic padding has been computed.\n" + # "Ensure your images are of an appropriate size and/or that you have enough memory." + # "The padding value is currently 2048." 
+    # )
+    pad = utils.get_padding_dim(size)
-
-    pytest.warns(warn, (lambda: utils.get_padding_dim(size)))
+    #
+    # pytest.warns(warn, (lambda: utils.get_padding_dim(size)))

     assert pad == [2048, 32, 64]

@@ -106,11 +110,19 @@ def test_normalize_x():


 def test_parse_default_path():
-    user_path = os.path.expanduser("~")
-    assert utils.parse_default_path([None]) == user_path
+    user_path = Path().home()
+    assert utils.parse_default_path([None]) == str(user_path)
+
+    test_path = "C:/test/test"
+    path = [test_path, None, None]
+    assert utils.parse_default_path(path) == test_path
+
+    long_path = "D:/very/long/path/what/a/bore/ifonlytherewassomethingtohelpmenottypeitiallthetime"
+    path = [test_path, None, None, long_path, ""]
+    assert utils.parse_default_path(path) == long_path

-    path = ["C:/test/test", None, None]
-    assert utils.parse_default_path(path) == "C:/test/test"

-    path = ["C:/test/test", None, None, "D:/very/long/path/what/a/bore", ""]
-    assert utils.parse_default_path(path) == "D:/very/long/path/what/a/bore"
+
+def test_thread_test(make_napari_viewer):
+    viewer = make_napari_viewer()
+    w = thread_test.create_connected_widget(viewer)
+    viewer.window.add_dock_widget(w)
diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py
index d8886a56..a00ab1de 100644
--- a/napari_cellseg3d/_tests/test_weight_download.py
+++ b/napari_cellseg3d/_tests/test_weight_download.py
@@ -1,5 +1,5 @@
-from napari_cellseg3d.code_models.model_workers import (
-    WEIGHTS_DIR,
+from napari_cellseg3d.code_models.workers import (
+    PRETRAINED_WEIGHTS_DIR,
     WeightsDownloader,
 )

@@ -8,6 +8,6 @@ def test_weight_download():
     downloader = WeightsDownloader()
     downloader.download_weights("test", "test.pth")
-    result_path = WEIGHTS_DIR / "test.pth"
+    result_path = PRETRAINED_WEIGHTS_DIR / "test.pth"

     assert result_path.is_file()
diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py
new file mode 100644
index 00000000..b362246a
--- /dev/null
+++ b/napari_cellseg3d/code_models/crf.py
@@ -0,0 +1,231 @@
+"""
+Implements the CRF post-processing step for the W-Net.
+Inspired by https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506.
+
+Also uses research from:
+Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials
+Philipp Krähenbühl and Vladlen Koltun
+NIPS 2011
+
+Implemented using the pydensecrf library available at https://github.com/lucasb-eyer/pydensecrf.
+"""
+from warnings import warn
+
+try:
+    import pydensecrf.densecrf as dcrf
+    from pydensecrf.utils import (
+        create_pairwise_bilateral,
+        create_pairwise_gaussian,
+        unary_from_softmax,
+    )
+
+    CRF_INSTALLED = True
+except ImportError:
+    warn(
+        "pydensecrf not installed, CRF post-processing will not be available. "
+        "Please install by running pip install cellseg3d[crf]",
+        stacklevel=1,
+    )
+    CRF_INSTALLED = False
+
+
+import numpy as np
+from napari.qt.threading import GeneratorWorker
+
+from napari_cellseg3d.config import CRFConfig
+from napari_cellseg3d.utils import LOGGER as logger
+
+__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard"
+__credits__ = [
+    "Yves Paychère",
+    "Colin Hofmann",
+    "Cyril Achard",
+    "Philipp Krähenbühl",
+    "Vladlen Koltun",
+    "Liang-Chieh Chen",
+    "George Papandreou",
+    "Iasonas Kokkinos",
+    "Kevin Murphy",
+    "Alan L. Yuille",
+    "Xide Xia",
+    "Brian Kulis",
+    "Lucas Beyer",
+]
+
+
+def correct_shape_for_crf(image, desired_dims=4):
+    logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}")
+    logger.debug(f"Image shape: {image.shape}")
+    if len(image.shape) > desired_dims:
+        # if image.shape[0] > 1:
+        #     raise ValueError(
+        #         f"Image shape {image.shape} might have several channels"
+        #     )
+        image = np.squeeze(image, axis=0)
+    elif len(image.shape) < desired_dims:
+        image = np.expand_dims(image, axis=0)
+    logger.debug(f"Corrected image shape: {image.shape}")
+    return image
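+
+
+# Illustrative usage sketch (not part of the original module): refining a mock
+# softmax prediction with the helpers defined in this file. Shapes follow the
+# docstrings below: image (C, H, W, D), probabilities (K, H, W, D).
+#
+#   import numpy as np
+#   from napari_cellseg3d.code_models.crf import crf_with_config
+#   from napari_cellseg3d.config import CRFConfig
+#
+#   image = np.random.rand(1, 64, 64, 64).astype(np.float32)
+#   probs = np.random.rand(2, 64, 64, 64).astype(np.float32)
+#   probs /= probs.sum(axis=0, keepdims=True)  # valid per-pixel softmax
+#   refined = crf_with_config(image, probs, CRFConfig())  # -> (2, 64, 64, 64)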
Yuille", + "Xide Xia", + "Brian Kulis", + "Lucas Beyer", +] + + +def correct_shape_for_crf(image, desired_dims=4): + logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}") + logger.debug(f"Image shape: {image.shape}") + if len(image.shape) > desired_dims: + # if image.shape[0] > 1: + # raise ValueError( + # f"Image shape {image.shape} might have several channels" + # ) + image = np.squeeze(image, axis=0) + elif len(image.shape) < desired_dims: + image = np.expand_dims(image, axis=0) + logger.debug(f"Corrected image shape: {image.shape}") + return image + + +def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5): + """CRF post-processing step for the W-Net, applied to a batch of images. + + Args: + images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images. + probs (np.ndarray): Array of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (N, K, H, W, D) containing the refined class probabilities for each pixel. + """ + if not CRF_INSTALLED: + return None + + return np.stack( + [ + crf(images[i], probs[i], sa, sb, sg, w1, w2, n_iter=n_iter) + for i in range(images.shape[0]) + ], + axis=0, + ) + + +def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5): + """Implements the CRF post-processing step for the W-Net. + Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506. + Implemented using the pydensecrf library. + + Args: + image (np.ndarray): Array of shape (C, H, W, D) containing the input image. + prob (np.ndarray): Array of shape (K, H, W, D) containing the predicted class probabilities for each pixel. + sa (float): alpha standard deviation, the scale of the spatial part of the appearance/bilateral kernel. + sb (float): beta standard deviation, the scale of the color part of the appearance/bilateral kernel. + sg (float): gamma standard deviation, the scale of the smoothness/gaussian kernel. + w1 (float): weight of the appearance/bilateral kernel. + w2 (float): weight of the smoothness/gaussian kernel. + + Returns: + np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel. 
+ """ + + if not CRF_INSTALLED: + return None + + d = dcrf.DenseCRF( + image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0] + ) + # print(f"Image shape : {image.shape}") + # print(f"Prob shape : {prob.shape}") + # d = dcrf.DenseCRF(262144, 3) # npoints, nlabels + + # Get unary potentials from softmax probabilities + U = unary_from_softmax(prob) + d.setUnaryEnergy(U) + + # Generate pairwise potentials + featsGaussian = create_pairwise_gaussian( + sdims=(sg, sg, sg), shape=image.shape[1:] + ) # image.shape) + featsBilateral = create_pairwise_bilateral( + sdims=(sa, sa, sa), + schan=tuple([sb for i in range(image.shape[0])]), + img=image, + chdim=-1, + ) + + # Add pairwise potentials to the CRF + compat = np.ones(prob.shape[0], dtype=np.float32) - np.diag( + [1 for i in range(prob.shape[0])] + # , dtype=np.float32 + ) + d.addPairwiseEnergy(featsGaussian, compat=compat.astype(np.float32) * w2) + d.addPairwiseEnergy(featsBilateral, compat=compat.astype(np.float32) * w1) + + # Run inference + Q = d.inference(n_iter) + + return np.array(Q).reshape( + (prob.shape[0], image.shape[1], image.shape[2], image.shape[3]) + ) + + +def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info): + if config is None: + config = CRFConfig() + if image.shape[-3:] != prob.shape[-3:]: + raise ValueError( + f"Image and probability shapes do not match: {image.shape} vs {prob.shape}" + f" (expected {image.shape[-3:]} == {prob.shape[-3:]})" + ) + + image = correct_shape_for_crf(image) + prob = correct_shape_for_crf(prob) + + if log is not None: + log("Running CRF post-processing step") + log(f"Image shape : {image.shape}") + log(f"Labels shape : {prob.shape}") + + return crf( + image, + prob, + config.sa, + config.sb, + config.sg, + config.w1, + config.w2, + config.n_iters, + ) + + +class CRFWorker(GeneratorWorker): + """Worker for the CRF post-processing step for the W-Net.""" + + def __init__( + self, + images_list: list, + labels_list: list, + config: CRFConfig = None, + log=None, + ): + super().__init__(self._run_crf_job) + + self.images = images_list + self.labels = labels_list + if config is None: + self.config = CRFConfig() + else: + self.config = config + self.log = log + + def _run_crf_job(self): + """Runs the CRF post-processing step for the W-Net.""" + if not CRF_INSTALLED: + raise ImportError("pydensecrf is not installed.") + + if len(self.images) != len(self.labels): + raise ValueError("Number of images and labels must be the same.") + + for i in range(len(self.images)): + if self.images[i].shape[-3:] != self.labels[i].shape[-3:]: + raise ValueError("Image and labels must have the same shape.") + + im = correct_shape_for_crf(self.images[i]) + prob = correct_shape_for_crf(self.labels[i]) + + logger.debug(f"image shape : {im.shape}") + logger.debug(f"labels shape : {prob.shape}") + + yield crf( + im, + prob, + self.config.sa, + self.config.sb, + self.config.sg, + self.config.w1, + self.config.w2, + n_iter=self.config.n_iters, + ) diff --git a/napari_cellseg3d/code_models/model_instance_seg.py b/napari_cellseg3d/code_models/instance_segmentation.py similarity index 85% rename from napari_cellseg3d/code_models/model_instance_seg.py rename to napari_cellseg3d/code_models/instance_segmentation.py index 60f8bbda..535bd429 100644 --- a/napari_cellseg3d/code_models/model_instance_seg.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -1,4 +1,6 @@ +import abc from dataclasses import dataclass +from functools import partial from typing import List import numpy as np @@ -7,15 
+9,17 @@ from skimage.measure import label, regionprops from skimage.morphology import remove_small_objects from skimage.segmentation import watershed - -# from skimage.measure import mesh_surface_area -# from skimage.measure import marching_cubes from tifffile import imread +# local from napari_cellseg3d import interface as ui from napari_cellseg3d.utils import LOGGER as logger from napari_cellseg3d.utils import fill_list_in_between, sphericity_axis +# from skimage.measure import marching_cubes +# from skimage.measure import mesh_surface_area + + # from napari_cellseg3d.utils import sphericity_volume_area # list of methods : @@ -48,6 +52,18 @@ def __init__( self.function = function self.counters: List[ui.DoubleIncrementCounter] = [] self.sliders: List[ui.Slider] = [] + self._setup_widgets( + num_counters, num_sliders, widget_parent=widget_parent + ) + + def _setup_widgets(self, num_counters, num_sliders, widget_parent=None): + """Initializes the needed widgets for the instance segmentation method, adding sliders and counters to the + instance segmentation widget. + Args: + num_counters: Number of DoubleIncrementCounter UI elements needed to set the parameters of the function + num_sliders: Number of Slider UI elements needed to set the parameters of the function + widget_parent: parent for the declared widgets + """ if num_sliders > 0: for i in range(num_sliders): widget = f"slider_{i}" @@ -60,7 +76,7 @@ def __init__( 1, divide_factor=100, text_label="", - parent=None, + parent=widget_parent, ), ) self.sliders.append(getattr(self, widget)) @@ -71,12 +87,37 @@ def __init__( setattr( self, widget, - ui.DoubleIncrementCounter(label="", parent=None), + ui.DoubleIncrementCounter( + text_label="", parent=widget_parent + ), ) self.counters.append(getattr(self, widget)) + @abc.abstractmethod def run_method(self, image): - raise NotImplementedError("Must be defined in child classes") + raise NotImplementedError() + + def _make_list_from_channels( + self, image + ): # TODO(cyril) : adapt to batch dimension + if len(image.shape) > 4: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at most 4 dimensions (CHWD)" + ) + if len(image.shape) < 2: + raise ValueError( + f"Image has {len(image.shape)} dimensions, but should have at least 2 dimensions (HW)" + ) + if len(image.shape) == 4: + image = np.squeeze(image) + if len(image.shape) == 4: + return [im for im in image] + return [image] + + def run_method_on_channels(self, image): + image_list = self._make_list_from_channels(image) + result = np.array([self.run_method(im) for im in image_list]) + return result.squeeze() @dataclass @@ -125,6 +166,8 @@ def voronoi_otsu( Voronoi-Otsu labeling from pyclesperanto. BASED ON CODE FROM : napari_pyclesperanto_assistant by Robert Haase https://github.com/clEsperanto/napari_pyclesperanto_assistant + Original code at : + https://github.com/clEsperanto/pyclesperanto_prototype/blob/master/pyclesperanto_prototype/_tier9/_voronoi_otsu_labeling.py Args: volume (np.ndarray): volume to segment @@ -159,8 +202,6 @@ def binary_connected( volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`. thres (float): threshold of foreground. Default: 0.8 thres_small (int): size threshold of small objects to remove. Default: 128 - scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. 
Default: (1.0, 1.0, 1.0) - """ logger.debug( f"Running connected components segmentation with thres={thres} and thres_small={thres_small}" @@ -272,25 +313,23 @@ def clear_small_objects(image, threshold, is_file_path=False): return result -def to_instance(image, is_file_path=False): - """Converts a **ground-truth** label to instance (unique id per object) labels. Does not remove small objects. - - Args: - image: image or path to image - is_file_path: if True, will consider ``image`` to be a string containing a path to a file, if not treats it as an image data array. - - Returns: resulting converted labels - - """ - if is_file_path: - image = [imread(image)] - # image = image.compute() - - result = binary_watershed( - image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 - ) # FIXME add params from utils plugin - - return result +# def to_instance(image, is_file_path=False): +# """Converts a **ground-truth** label to instance (unique id per object) labels. Does not remove small objects. +# +# Args: +# image: image or path to image +# is_file_path: if True, will consider ``image`` to be a string containing a path to a file, if not treats it as an image data array. +# +# Returns: resulting converted labels +# +# """ +# if is_file_path: +# image = [imread(image)] +# image = image.compute() +# +# return binary_watershed( +# image, thres_small=0, thres_seeding=0.3, rem_seed_thres=0 +# ) def to_semantic(image, is_file_path=False): @@ -308,8 +347,7 @@ def to_semantic(image, is_file_path=False): # image = image.compute() image[image >= 1] = 1 - result = image.astype(np.uint16) - return result + return image.astype(np.uint16) def volume_stats(volume_image): @@ -358,8 +396,10 @@ def sphericity(region): volume = [region.area for region in properties] - def fill(lst, n=len(properties) - 1): - return fill_list_in_between(lst, n, "") + # def fill(lst, n=len(properties) - 1): + # return fill_list_in_between(lst, n, "") + + fill = partial(fill_list_in_between, n=len(properties) - 1, fill_value="") if len(volume_image.flatten()) != 0: ratio = fill([np.sum(volume) / len(volume_image.flatten())]) @@ -369,7 +409,7 @@ def fill(lst, n=len(properties) - 1): return ImageStats( volume, [region.centroid[0] for region in properties], - [region.centroid[0] for region in properties], + [region.centroid[1] for region in properties], [region.centroid[2] for region in properties], sphericity_ax, fill([volume_image.shape]), @@ -392,13 +432,13 @@ def __init__(self, widget_parent=None): widget_parent=widget_parent, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" self.sliders[0].setValue(50) - self.sliders[1].text_label.setText("Seed probability threshold") + self.sliders[1].label.setText("Seed probability threshold") self.sliders[1].tooltips = "Probability threshold for seeding" self.sliders[1].setValue(90) @@ -438,7 +478,7 @@ def __init__(self, widget_parent=None): widget_parent=widget_parent, ) - self.sliders[0].text_label.setText("Foreground probability threshold") + self.sliders[0].label.setText("Foreground probability threshold") self.sliders[ 0 ].tooltips = "Probability threshold for foreground object" @@ -578,7 +618,7 @@ def run_method(self, volume): """ method = self.methods[self.method_choice.currentText()] - return method.run_method(volume) + return method.run_method_on_channels(volume) INSTANCE_SEGMENTATION_METHOD_LIST = { diff --git 
a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 2cc4265e..ddd9cd28 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,9 +1,11 @@ -import warnings from pathlib import Path +from typing import TYPE_CHECKING -import napari import torch +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QProgressBar, QSizePolicy @@ -12,7 +14,6 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -78,7 +79,7 @@ def __init__( # ) self.model_choice = ui.DropdownMenu( - sorted(self.available_models.keys()), label="Model name" + sorted(self.available_models.keys()), text_label="Model name" ) self.weights_filewidget = ui.FilePathWidget( @@ -128,18 +129,18 @@ def save_log(self): path = self.results_path if len(log) != 0: - with open( + with Path.open( path + f"/Log_report_{utils.get_date_time()}.txt", "x", ) as f: f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) else: - warnings.warn(f"No logger defined : Log is {self.log}") + logger.warning(f"No logger defined : Log is {self.log}") def save_log_to_path(self, path): """Saves the worker log to a specific path. Cannot be used with connect. @@ -154,14 +155,14 @@ def save_log_to_path(self, path): ) if len(log) != 0: - with open( - path, + with Path.open( + Path(path), "x", ) as f: f.write(log) f.close() else: - warnings.warn( + logger.warning( "No job has been completed yet, please start one or re-open the log window." ) @@ -170,7 +171,7 @@ def display_status_report(self): (usually when starting a worker)""" # if self.container_report is None or self.log is None: - # warnings.warn( + # logger.warning( # "Status report widget has been closed. Trying to re-instantiate..." 
# ) # self.container_report = QWidget() @@ -272,6 +273,14 @@ def get_available_models(): # self.lbl_model_path.setText(self.model_path) # # self.update_default() + def _update_weights_path(self, file): + if file[0] == self._default_weights_folder: + return + if file is not None and file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) + def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`""" @@ -280,14 +289,9 @@ def _load_weights_path(self): file = ui.open_file_dialog( self, [self._default_weights_folder], - filetype="Weights file (*.pth)", + file_extension="Weights file (*.pth)", ) - if file[0] == self._default_weights_folder: - return - if file is not None and file[0] != "": - self.weights_config.path = file[0] - self.weights_filewidget.text_field.setText(file[0]) - self._default_weights_folder = str(Path(file[0]).parent) + self._update_weights_path(file) @staticmethod def get_device(show=True): @@ -307,31 +311,5 @@ def empty_cuda_cache(self): torch.cuda.empty_cache() logger.info("Attempt complete : Cache emptied") - # def update_default(self): # TODO add custom models - # """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" - # - # if len(self.images_filepaths) != 0: - # from_images = str(Path(self.images_filepaths[0]).parent) - # else: - # from_images = None - # - # if len(self.labels_filepaths) != 0: - # from_labels = str(Path(self.labels_filepaths[0]).parent) - # else: - # from_labels = None - # - # possible_paths = [ - # path - # for path in [ - # from_images, - # from_labels, - # # self.model_path, - # self.results_path, - # ] - # if path is not None - # ] - # self._default_folders = possible_paths - # update if model_path is used again - def _build(self): raise NotImplementedError("Should be defined in children classes") diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py new file mode 100644 index 00000000..f68e5f4f --- /dev/null +++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class ModelTemplate_(ABC): + use_default_training = True # not needed for now, will serve for WNet training if added to the plugin + weights_file = ( + "model_template.pth" # specify the file name of the weights file only + ) + + @abstractmethod + def __init__( + self, input_image_size, in_channels=1, out_channels=1, **kwargs + ): + """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" + pass + + @abstractmethod + def forward(self, x): + """Reimplement this as needed. 
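+        A minimal conforming override might look like the following sketch
+        (purely illustrative; `self.net` is a hypothetical wrapped torch
+        module, not part of the template):
+
+            def forward(self, x):
+                return self.net(x)
+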
Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" + pass diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py index 8856e18d..8b6e6e65 100644 --- a/napari_cellseg3d/code_models/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,21 +1,33 @@ from monai.networks.nets import SegResNetVAE -def get_net(input_image_size, out_channels=1, dropout_prob=0.3): - return SegResNetVAE( - input_image_size, out_channels=out_channels, dropout_prob=dropout_prob - ) - - -def get_weights_file(): - return "SegResNet.pth" - - -def get_output(model, input): - out = model(input)[0] - return out - - -def get_validation(model, val_inputs): - val_outputs = model(val_inputs) - return val_outputs[0] +class SegResNet_(SegResNetVAE): + use_default_training = True + weights_file = "SegResNet.pth" + + def __init__( + self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs + ): + super().__init__( + input_img_size, + out_channels=out_channels, + dropout_prob=dropout_prob, + ) + + def forward(self, x): + res = SegResNetVAE.forward(self, x) + # logger.debug(f"SegResNetVAE.forward: {res[0].shape}") + return res[0] + + def get_model_test(self, size): + return SegResNetVAE( + size, in_channels=1, out_channels=1, dropout_prob=0.3 + ) + + # def get_output(model, input): + # out = model(input)[0] + # return out + + # def get_validation(model, val_inputs): + # val_outputs = model(val_inputs) + # return val_outputs[0] diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 532aeb89..2d7b5ef6 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,25 +1,44 @@ -import torch from monai.networks.nets import SwinUNETR +from napari_cellseg3d.utils import LOGGER -def get_weights_file(): - return "Swin64_best_metric.pth" +logger = LOGGER -def get_net(img_size, use_checkpoint=True): - return SwinUNETR( - img_size, +class SwinUNETR_(SwinUNETR): + use_default_training = True + weights_file = "Swin64_best_metric.pth" + + def __init__( + self, in_channels=1, out_channels=1, - feature_size=48, - use_checkpoint=use_checkpoint, - ) - - -def get_output(model, input): - out = model(input) - return torch.sigmoid(out) - - -def get_validation(model, val_inputs): - return model(val_inputs) + input_img_size=128, + use_checkpoint=True, + **kwargs, + ): + try: + super().__init__( + input_img_size, + in_channels=in_channels, + out_channels=out_channels, + feature_size=48, + use_checkpoint=use_checkpoint, + **kwargs, + ) + except TypeError as e: + logger.warning(f"Caught TypeError: {e}") + super().__init__( + input_img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) + + # def get_output(self, input): + # out = self(input) + # return torch.sigmoid(out) + + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py index 09de2a26..e6bbad55 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py @@ -2,28 +2,8 @@ from torch import nn -def get_weights_file(): - # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" - # FIXME currently incorrect, find good weights from TRAILMAP_test 
and upload them - - -def get_net(): - return TRAILMAP(1, 1) - - -def get_output(model, input): - out = model(input) - - return out - - -def get_validation(model, val_inputs): - return model(val_inputs) - - class TRAILMAP(nn.Module): - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch, out_ch, *args, **kwargs): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input self.conv1 = self.encoderBlock(32, 64, 3) # l1 @@ -59,13 +39,12 @@ def forward(self, x): up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 # print(up8.shape) - out = self.out(up8) + return self.out(up8) # print("out:") # print(out.shape) - return out def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -76,10 +55,9 @@ def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.ReLU(), nn.MaxPool3d(2), ) - return encode def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): - encode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -89,10 +67,9 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): nn.BatchNorm3d(out_ch), nn.ReLU(), ) - return encode def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): - decode = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), nn.BatchNorm3d(out_ch), nn.ReLU(), @@ -105,10 +82,25 @@ def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) ), ) - return decode def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): - out = nn.Sequential( + return nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), ) - return out + + +class TRAILMAP_(TRAILMAP): + use_default_training = True + weights_file = "TRAILMAP_PyTorch.pth" # model additionally trained on Mathis/Wyss mesoSPIM data + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + + def __init__(self, in_channels=1, out_channels=1, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + + # def get_output(model, input): + # out = model(input) + # + # return out + + # def get_validation(model, val_inputs): + # return model(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 0fc68d34..baf8635d 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,20 +1,30 @@ from napari_cellseg3d.code_models.models.unet.model import UNet3D +from napari_cellseg3d.utils import LOGGER +logger = LOGGER -def get_weights_file(): - # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) - return "TRAILMAP_MS_best_metric_epoch_26.pth" - - -def get_net(): - return UNet3D(1, 1) +class TRAILMAP_MS_(UNet3D): + use_default_training = True + weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth" -def get_output(model, input): - out = model(input) - - return out - + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) -def get_validation(model, val_inputs): - return model(val_inputs) + def __init__(self, in_channels=1, 
out_channels=1, **kwargs): + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError as e: + logger.warning(f"Caught TypeError: {e}") + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) + + # def get_output(self, input): + # out = self(input) + + # return out + # + # def get_validation(self, val_inputs): + # return self(val_inputs) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 0c854832..b082ccab 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -1,29 +1,16 @@ -from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -def get_net(): - return VNet() +class VNet_(VNet): + use_default_training = True + weights_file = "VNet_40e.pth" - -def get_weights_file(): - return "VNet_40e.pth" - - -def get_output(model, input): - out = model(input) - return out - - -def get_validation(model, val_inputs): - roi_size = (64, 64, 64) - sw_batch_size = 1 - val_outputs = sliding_window_inference( - val_inputs, - roi_size, - sw_batch_size, - model, - mode="gaussian", - overlap=0.7, - ) - return val_outputs + def __init__(self, in_channels=1, out_channels=1, **kwargs): + try: + super().__init__( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + except TypeError: + super().__init__( + in_channels=in_channels, out_channels=out_channels + ) diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py new file mode 100644 index 00000000..a2fce724 --- /dev/null +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -0,0 +1,42 @@ +# local +from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder + + +class WNet_(WNet_encoder): + use_default_training = False + weights_file = "wnet.pth" + + def __init__( + self, + in_channels=1, + out_channels=2, + # num_classes=2, + device="cpu", + **kwargs, + ): + super().__init__( + device=device, + in_channels=in_channels, + out_channels=out_channels, + # num_classes=num_classes, + ) + + # def train(self: T, mode: bool = True) -> T: + # raise NotImplementedError("Training not implemented for WNet") + + # def forward(self, x): + # """Forward ENCODER pass of the W-Net model. + # Done this way to allow inference on the encoder only when called by sliding_window_inference. 
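+    #     (Note: MONAI's sliding_window_inference calls the model itself on
+    #     each window, so routing forward() to the encoder like this is what
+    #     keeps windowed inference encoder-only.)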
+ # """ + # return self.forward_encoder(x) + # # enc = self.forward_encoder(x) + # # return self.forward_decoder(enc) + + def load_state_dict(self, state_dict, strict=False): + """Load the model state dict for inference, without the decoder weights.""" + encoder_checkpoint = state_dict.copy() + for k in state_dict: + if k.startswith("decoder"): + encoder_checkpoint.pop(k) + # print(encoder_checkpoint.keys()) + super().load_state_dict(encoder_checkpoint, strict=strict) diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 5871c4a7..28f3a05b 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -2,35 +2,31 @@ from torch import nn -def get_weights_file(): - return "test.pth" - - class TestModel(nn.Module): - def __init__(self): + use_default_training = True + weights_file = "test.pth" + + def __init__(self, **kwargs): super().__init__() - self.linear = nn.Linear(1, 1) + self.linear = nn.Linear(8, 8) def forward(self, x): return self.linear(torch.tensor(x, requires_grad=True)) - def get_net(self): - return self + # def get_output(self, _, input): + # return input - def get_output(self, _, input): - return input + # def get_validation(self, val_inputs): + # return val_inputs - def get_validation(self, val_inputs): - return val_inputs +if __name__ == "__main__": + model = TestModel() + model.train() + model.zero_grad() + from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR -# if __name__ == "__main__": -# -# model = TestModel() -# model.train() -# model.zero_grad() -# from napari_cellseg3d.config import WEIGHTS_DIR -# torch.save( -# model.state_dict(), -# WEIGHTS_DIR + f"/{get_weights_file()}" -# ) + torch.save( + model.state_dict(), + PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", + ) diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index cd0782fb..cde5e332 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -3,5 +3,6 @@ "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz", "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", + "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" } diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py index 73913ab8..ce7d378f 100644 --- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py +++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py @@ -422,8 +422,7 @@ def forward(self, encoder_features, x): def _joining(encoder_features, x, concat): if concat: return torch.cat((encoder_features, x), dim=1) - else: - return encoder_features + x + return encoder_features + x def create_encoders( diff --git a/napari_cellseg3d/code_models/models/wnet/__init__.py b/napari_cellseg3d/code_models/models/wnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py new file mode 100644 index 00000000..5ef726b6 --- /dev/null +++ 
b/napari_cellseg3d/code_models/models/wnet/model.py @@ -0,0 +1,240 @@ +""" +Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. +The model performs unsupervised segmentation of 3D images. +""" + +from typing import List + +import torch +import torch.nn as nn + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", +] + + +class WNet_encoder(nn.Module): + """WNet with encoder only.""" + + def __init__( + self, + device, + in_channels=1, + out_channels=2 + # num_classes=2 + ): + super().__init__() + self.device = device + self.encoder = UNet( + in_channels=in_channels, + out_channels=out_channels, + encoder=True, + ) + + def forward(self, x): + """Forward pass of the W-Net model.""" + return self.encoder(x) + + +class WNet(nn.Module): + """Implementation of a 3D W-Net model, based on the 2D version from https://arxiv.org/abs/1711.08506. + The model performs unsupervised segmentation of 3D images. + It first encodes the input image into a latent space using the U-Net UEncoder, then decodes it back to the original image using the U-Net UDecoder. + """ + + def __init__( + self, + in_channels=1, + out_channels=2, + num_classes=2, + dropout=0.65, + ): + super(WNet, self).__init__() + self.encoder = UNet( + in_channels, num_classes, encoder=True, dropout=dropout + ) + self.decoder = UNet( + num_classes, out_channels, encoder=False, dropout=dropout + ) + + def forward(self, x): + """Forward pass of the W-Net model.""" + enc = self.forward_encoder(x) + return enc, self.forward_decoder(enc) + + def forward_encoder(self, x): + """Forward pass of the encoder part of the W-Net model.""" + return self.encoder(x) + + def forward_decoder(self, enc): + """Forward pass of the decoder part of the W-Net model.""" + return self.decoder(enc) + + +class UNet(nn.Module): + """Half of the W-Net model, based on the U-Net architecture.""" + + def __init__( + self, + # device, + in_channels: int, + out_channels: int, + channels: List[int] = None, + encoder: bool = True, + dropout: float = 0.65, + ): + if channels is None: + channels = [64, 128, 256, 512, 1024] + if len(channels) != 5: + raise ValueError( + "Channels must be a list of channels in the form: [64, 128, 256, 512, 1024]" + ) + super(UNet, self).__init__() + # self.device = device + self.channels = channels + self.max_pool = nn.MaxPool3d(2) + self.in_b = InBlock(in_channels, self.channels[0], dropout=dropout) + self.conv1 = Block(channels[0], self.channels[1], dropout=dropout) + self.conv2 = Block(channels[1], self.channels[2], dropout=dropout) + self.conv3 = Block(channels[2], self.channels[3], dropout=dropout) + self.bot = Block(channels[3], self.channels[4], dropout=dropout) + self.deconv1 = Block(channels[4], self.channels[3], dropout=dropout) + self.conv_trans1 = nn.ConvTranspose3d( + self.channels[4], self.channels[3], 2, stride=2 + ) + self.deconv2 = Block(channels[3], self.channels[2], dropout=dropout) + self.conv_trans2 = nn.ConvTranspose3d( + self.channels[3], self.channels[2], 2, stride=2 + ) + self.deconv3 = Block(channels[2], self.channels[1], dropout=dropout) + self.conv_trans3 = nn.ConvTranspose3d( + self.channels[2], self.channels[1], 2, stride=2 + ) + self.out_b = OutBlock(channels[1], out_channels, dropout=dropout) + self.conv_trans_out = nn.ConvTranspose3d( + self.channels[1], self.channels[0], 2, stride=2 + ) + + self.sm = nn.Softmax(dim=1) + self.encoder = encoder + + def forward(self, 
x): + """Forward pass of the U-Net model.""" + in_b = self.in_b(x) + c1 = self.conv1(self.max_pool(in_b)) + c2 = self.conv2(self.max_pool(c1)) + c3 = self.conv3(self.max_pool(c2)) + x = self.bot(self.max_pool(c3)) + x = self.deconv1( + torch.cat( + [ + c3, + self.conv_trans1(x), + ], + dim=1, + ) + ) + x = self.deconv2( + torch.cat( + [ + c2, + self.conv_trans2(x), + ], + dim=1, + ) + ) + x = self.deconv3( + torch.cat( + [ + c1, + self.conv_trans3(x), + ], + dim=1, + ) + ) + x = self.out_b( + torch.cat( + [ + in_b, + self.conv_trans_out(x), + ], + dim=1, + ) + ) + if self.encoder: + x = self.sm(x) + return x + + +class InBlock(nn.Module): + """Input block of the U-Net architecture.""" + + def __init__(self, in_channels, out_channels, dropout=0.65): + super(InBlock, self).__init__() + # self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, out_channels, 3, padding=1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(out_channels), + nn.Conv3d(out_channels, out_channels, 3, padding=1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(out_channels), + ) + + def forward(self, x): + """Forward pass of the input block.""" + return self.module(x) + + +class Block(nn.Module): + """Basic block of the U-Net architecture.""" + + def __init__(self, in_channels, out_channels, dropout=0.65): + super(Block, self).__init__() + # self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, in_channels, 3, padding=1), + nn.Conv3d(in_channels, out_channels, 1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(out_channels), + nn.Conv3d(out_channels, out_channels, 3, padding=1), + nn.Conv3d(out_channels, out_channels, 1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(out_channels), + ) + + def forward(self, x): + """Forward pass of the basic block.""" + return self.module(x) + + +class OutBlock(nn.Module): + """Output block of the U-Net architecture.""" + + def __init__(self, in_channels, out_channels, dropout=0.65): + super(OutBlock, self).__init__() + # self.device = device + self.module = nn.Sequential( + nn.Conv3d(in_channels, 64, 3, padding=1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(64), + nn.Conv3d(64, 64, 3, padding=1), + nn.ReLU(), + nn.Dropout(p=dropout), + nn.BatchNorm3d(64), + nn.Conv3d(64, out_channels, 1), + ) + + def forward(self, x): + """Forward pass of the output block.""" + return self.module(x) diff --git a/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py new file mode 100644 index 00000000..e0f92ff7 --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py @@ -0,0 +1,225 @@ +""" +Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. +The implementation was adapted and approximated to reduce computational and memory cost. +This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. +""" + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.stats import norm + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" +__credits__ = [ + "Yves Paychère", + "Colin Hofmann", + "Cyril Achard", + "Xide Xia", + "Brian Kulis", + "Jianbo Shi", + "Jitendra Malik", + "Frank Odom", +] + + +class SoftNCutsLoss(nn.Module): + """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. 
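+
+    Roughly, for predicted class probabilities p_k and pairwise pixel weights
+    w(i, j) = exp(-d(i, j)^2 / spatial_sigma^2) * exp(-dI(i, j)^2 / intensity_sigma^2),
+    the soft N-cut objective computed here approximates
+
+        loss = K - sum_k [ sum_{i,j} w(i, j) p_k(i) p_k(j) ] / [ sum_{i,j} w(i, j) p_k(i) ],
+
+    with the inner sums restricted to pixels within `radius` of each other and
+    evaluated via a 3D Gaussian convolution. As a further approximation, the
+    intensity term uses each pixel's squared distance to the class mean
+    intensity rather than a fully pairwise brightness difference.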
+
+    Args:
+        data_shape (H, W, D): shape of the images as a tuple.
+        device: torch device on which the loss is computed.
+        intensity_sigma (scalar): scale of the Gaussian kernel of pixel brightness.
+        spatial_sigma (scalar): scale of the Gaussian kernel of pixel spatial distance.
+        radius (scalar): radius of pixels for which we compute the weights.
+    """
+
+    def __init__(
+        self, data_shape, device, intensity_sigma, spatial_sigma, radius=None
+    ):
+        super(SoftNCutsLoss, self).__init__()
+        self.intensity_sigma = intensity_sigma
+        self.spatial_sigma = spatial_sigma
+        self.radius = radius
+        self.H = data_shape[0]
+        self.W = data_shape[1]
+        self.D = data_shape[2]
+        self.device = device
+
+        if self.radius is None:
+            self.radius = min(
+                max(5, math.ceil(min(self.H, self.W, self.D) / 20)),
+                self.H,
+                self.W,
+                self.D,
+            )
+        print(f"Radius set to {self.radius}")
+
+    def forward(self, labels, inputs):
+        """Forward pass of the Soft N-Cuts loss.
+
+        Args:
+            labels (torch.Tensor): Tensor of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel.
+            inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images.
+
+        Returns:
+            The Soft N-Cuts loss as a scalar tensor.
+        """
+        K = labels.shape[1]
+
+        labels = labels.to(self.device)
+        inputs = inputs.to(self.device)
+
+        loss = 0
+
+        kernel = self.gaussian_kernel(self.radius, self.spatial_sigma).to(
+            self.device
+        )
+
+        for k in range(K):
+            # Compute the average pixel value for this class, and the difference from each pixel
+            class_probs = labels[:, k].unsqueeze(1)
+            class_mean = torch.mean(
+                inputs * class_probs, dim=(2, 3, 4), keepdim=True
+            ) / torch.add(
+                torch.mean(class_probs, dim=(2, 3, 4), keepdim=True), 1e-5
+            )
+            diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1)
+
+            # Weight the loss by the difference from the class average.
+            weights = torch.exp(
+                diff.pow(2).mul(-1 / self.intensity_sigma**2)
+            )
+
+            numerator = torch.sum(
+                class_probs
+                * F.conv3d(class_probs * weights, kernel, padding=self.radius),
+                dim=(1, 2, 3, 4),
+            )
+            denominator = torch.sum(
+                class_probs * F.conv3d(weights, kernel, padding=self.radius),
+                dim=(1, 2, 3, 4),
+            )
+            loss += nn.L1Loss()(
+                numerator / torch.add(denominator, 1e-6),
+                torch.zeros_like(numerator),
+            )
+
+        return K - loss
+
+    def gaussian_kernel(self, radius, sigma):
+        """Computes the Gaussian kernel.
+
+        Args:
+            radius (int): The radius of the kernel.
+            sigma (float): The standard deviation of the Gaussian distribution.
+
+        Returns:
+            The Gaussian kernel of shape (1, 1, 2*radius+1, 2*radius+1, 2*radius+1).
+        """
+        x_2 = np.linspace(-radius, radius, 2 * radius + 1) ** 2
+        dist = (
+            np.sqrt(
+                x_2.reshape(-1, 1, 1)
+                + x_2.reshape(1, -1, 1)
+                + x_2.reshape(1, 1, -1)
+            )
+            / sigma
+        )
+        kernel = norm.pdf(dist) / norm.pdf(0)
+        kernel = torch.from_numpy(kernel.astype(np.float32))
+        return kernel.view(
+            (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2])
+        )
+
+    def get_distances(self):
+        """Precompute the spatial distance of the pixels for the weights calculation, to avoid recomputing it at each iteration.
+
+        Returns:
+            distances (dict): for each pixel index, the distances to the pixels within a radius around it.
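+
+        Note: this stores one entry per (pixel, neighbour) pair, so memory
+        grows roughly as H * W * D * (2 * radius + 1)^3. It is a
+        precomputation helper (returned together with the flat index array)
+        and is not called from forward().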
+ """ + + distances = dict() + indexes = np.array( + [ + (i, j, k) + for i in range(self.H) + for j in range(self.W) + for k in range(self.D) + ] + ) + + for i in indexes: + iTuple = (i[0], i[1], i[2]) + distances[iTuple] = dict() + + sliceD = indexes[ + i[0] * self.H + + i[1] * self.W + + max(0, i[2] - self.radius) : i[0] * self.H + + i[1] * self.W + + min(self.D, i[2] + self.radius) + ] + sliceW = indexes[ + i[0] * self.H + + max(0, i[1] - self.radius) * self.W + + i[2] : i[0] * self.H + + min(self.W, i[1] + self.radius) * self.W + + i[2] : self.D + ] + sliceH = indexes[ + max(0, i[0] - self.radius) * self.H + + i[1] * self.W + + i[2] : min(self.H, i[0] + self.radius) * self.H + + i[1] * self.W + + i[2] : self.D * self.W + ] + + for j in np.concatenate((sliceD, sliceW, sliceH)): + jTuple = (j[0], j[1], j[2]) + distance = np.linalg.norm(i - j) + if distance > self.radius: + continue + distance = math.exp( + -(distance**2) / (self.spatial_sigma**2) + ) + + if jTuple not in distances: + distances[iTuple][jTuple] = distance + + return distances, indexes + + def get_weights(self, inputs): + """Computes the weights matrix for the Soft N-Cuts loss. + + Args: + inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. + + Returns: + list: List of the weights dict for each image in the batch. + """ + + # Compute the brightness distance of the pixels + flatted_inputs = inputs.view( + inputs.shape[0], inputs.shape[1], -1 + ) # (N, C, H*W*D) + I_diff = torch.subtract( + flatted_inputs.unsqueeze(3), flatted_inputs.unsqueeze(2) + ) # (N, C, H*W*D, H*W*D) + masked_I_diff = torch.mul(I_diff, self.mask) # (N, C, H*W*D, H*W*D) + squared_I_diff = torch.pow(masked_I_diff, 2) # (N, C, H*W*D, H*W*D) + + W_I = torch.exp( + torch.neg(torch.div(squared_I_diff, self.intensity_sigma)) + ) # (N, C, H*W*D, H*W*D) + W_I = torch.mul(W_I, self.mask) # (N, C, H*W*D, H*W*D) + + # Get the spatial distance of the pixels + unsqueezed_W_X = self.W_X.view( + 1, 1, self.W_X.shape[0], self.W_X.shape[1] + ) # (1, 1, H*W*D, H*W*D) + + return torch.mul(W_I, unsqueezed_W_X) # (N, C, H*W*D, H*W*D) diff --git a/napari_cellseg3d/code_models/models/wnet/train_wnet.py b/napari_cellseg3d/code_models/models/wnet/train_wnet.py new file mode 100644 index 00000000..61d8959a --- /dev/null +++ b/napari_cellseg3d/code_models/models/wnet/train_wnet.py @@ -0,0 +1,1008 @@ +""" +This file contains the code to train the WNet model. 
+""" +# import napari +import glob +import time +from pathlib import Path +from typing import Union +from warnings import warn + +import numpy as np +import tifffile as tiff +import torch +import torch.nn as nn + +# MONAI +from monai.data import ( + CacheDataset, + DataLoader, + PatchDataset, + pad_list_data_collate, +) +from monai.data.meta_obj import set_track_meta +from monai.metrics import DiceMetric +from monai.transforms import ( + AsDiscrete, + Compose, + EnsureChannelFirst, + EnsureChannelFirstd, + EnsureTyped, + LoadImaged, + Orientationd, + RandFlipd, + RandRotate90d, + RandShiftIntensityd, + RandSpatialCropSamplesd, + ScaleIntensityRanged, + SpatialPadd, + ToTensor, +) +from monai.utils.misc import set_determinism + +# local +from napari_cellseg3d.code_models.models.wnet.model import WNet +from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss +from napari_cellseg3d.utils import LOGGER as logger +from napari_cellseg3d.utils import dice_coeff, get_padding_dim + +try: + import wandb + + WANDB_INSTALLED = True +except ImportError: + warn( + "wandb not installed, wandb config will not be taken into account", + stacklevel=1, + ) + WANDB_INSTALLED = False + +__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" + + +########################## +# Utils functions # +########################## + + +def create_dataset_dict(volume_directory, label_directory): + """Creates data dictionary for MONAI transforms and training.""" + images_filepaths = sorted( + [str(file) for file in Path(volume_directory).glob("*.tif")] + ) + + labels_filepaths = sorted( + [str(file) for file in Path(label_directory).glob("*.tif")] + ) + if len(images_filepaths) == 0 or len(labels_filepaths) == 0: + raise ValueError( + f"Data folders are empty \n{volume_directory} \n{label_directory}" + ) + + logger.info("Images :") + for file in images_filepaths: + logger.info(Path(file).stem) + logger.info("*" * 10) + logger.info("Labels :") + for file in labels_filepaths: + logger.info(Path(file).stem) + try: + data_dicts = [ + {"image": image_name, "label": label_name} + for image_name, label_name in zip( + images_filepaths, labels_filepaths + ) + ] + except ValueError as e: + raise ValueError( + f"Number of images and labels does not match : \n{volume_directory} \n{label_directory}" + ) from e + # print(f"Loaded eval image: {data_dicts}") + return data_dicts + + +def create_dataset_dict_no_labs(volume_directory): + """Creates unsupervised data dictionary for MONAI transforms and training.""" + images_filepaths = sorted(glob.glob(str(Path(volume_directory) / "*.tif"))) + if len(images_filepaths) == 0: + raise ValueError(f"Data folder {volume_directory} is empty") + + logger.info("Images :") + for file in images_filepaths: + logger.info(Path(file).stem) + logger.info("*" * 10) + + return [{"image": image_name} for image_name in images_filepaths] + + +def remap_image( + image: Union[np.ndarray, torch.Tensor], new_max=100, new_min=0 +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image * (new_max - new_min) + new_min + # image = set_quantile_to_value(image) + return image.reshape(shape) + + +################################ +# Config & WANDB # +################################ + + +class Config: + def __init__(self): + # WNet + self.in_channels = 1 + self.out_channels = 1 + self.num_classes = 2 + self.dropout = 0.65 + self.use_clipping = False + 
self.clipping = 1 + + self.lr = 1e-6 + self.scheduler = "None" # "CosineAnnealingLR" # "ReduceLROnPlateau" + self.weight_decay = 0.01 # None + + self.intensity_sigma = 1 + self.spatial_sigma = 4 + self.radius = 2 # yields to a radius depending on the data shape + + self.n_cuts_weight = 0.5 + self.reconstruction_loss = "MSE" # "BCE" + self.rec_loss_weight = 0.5 / 100 + + self.num_epochs = 100 + self.val_interval = 5 + self.batch_size = 2 + self.num_workers = 4 + + # CRF + self.sa = 50 # 10 + self.sb = 20 + self.sg = 1 + self.w1 = 50 # 10 + self.w2 = 20 + self.n_iter = 5 + + # Data + self.train_volume_directory = "./../dataset/VIP_full" + self.eval_volume_directory = "./../dataset/VIP_cropped/eval/" + self.normalize_input = True + self.normalizing_function = remap_image # normalize_quantile + self.use_patch = False + self.patch_size = (64, 64, 64) + self.num_patches = 30 + self.eval_num_patches = 20 + self.do_augmentation = True + self.parallel = False + + self.save_model = True + self.save_model_path = ( + r"./../results/new_model/wnet_new_model_all_data_3class.pth" + ) + # self.save_losses_path = ( + # r"./../results/new_model/wnet_new_model_all_data_3class.pkl" + # ) + self.save_every = 5 + self.weights_path = None + + +c = Config() +############### +# Scheduler config +############### +schedulers = { + "ReduceLROnPlateau": { + "factor": 0.5, + "patience": 50, + }, + "CosineAnnealingLR": { + "T_max": 25000, + "eta_min": 1e-8, + }, + "CosineAnnealingWarmRestarts": { + "T_0": 50000, + "eta_min": 1e-8, + "T_mult": 1, + }, + "CyclicLR": { + "base_lr": 2e-7, + "max_lr": 2e-4, + "step_size_up": 250, + "mode": "triangular", + }, +} + +############### +# WANDB_CONFIG +############### +WANDB_MODE = "disabled" +# WANDB_MODE = "online" + +WANDB_CONFIG = { + # data setting + "num_workers": c.num_workers, + "normalize": c.normalize_input, + "use_patch": c.use_patch, + "patch_size": c.patch_size, + "num_patches": c.num_patches, + "eval_num_patches": c.eval_num_patches, + "do_augmentation": c.do_augmentation, + "model_save_path": c.save_model_path, + # train setting + "batch_size": c.batch_size, + "learning_rate": c.lr, + "weight_decay": c.weight_decay, + "scheduler": { + "name": c.scheduler, + "ReduceLROnPlateau_config": { + "factor": schedulers["ReduceLROnPlateau"]["factor"], + "patience": schedulers["ReduceLROnPlateau"]["patience"], + }, + "CosineAnnealingLR_config": { + "T_max": schedulers["CosineAnnealingLR"]["T_max"], + "eta_min": schedulers["CosineAnnealingLR"]["eta_min"], + }, + "CosineAnnealingWarmRestarts_config": { + "T_0": schedulers["CosineAnnealingWarmRestarts"]["T_0"], + "eta_min": schedulers["CosineAnnealingWarmRestarts"]["eta_min"], + "T_mult": schedulers["CosineAnnealingWarmRestarts"]["T_mult"], + }, + "CyclicLR_config": { + "base_lr": schedulers["CyclicLR"]["base_lr"], + "max_lr": schedulers["CyclicLR"]["max_lr"], + "step_size_up": schedulers["CyclicLR"]["step_size_up"], + "mode": schedulers["CyclicLR"]["mode"], + }, + }, + "max_epochs": c.num_epochs, + "save_every": c.save_every, + "val_interval": c.val_interval, + # loss + "reconstruction_loss": c.reconstruction_loss, + "loss weights": { + "n_cuts_weight": c.n_cuts_weight, + "rec_loss_weight": c.rec_loss_weight, + }, + "loss_params": { + "intensity_sigma": c.intensity_sigma, + "spatial_sigma": c.spatial_sigma, + "radius": c.radius, + }, + # model + "model_type": "wnet", + "model_params": { + "in_channels": c.in_channels, + "out_channels": c.out_channels, + "num_classes": c.num_classes, + "dropout": c.dropout, + "use_clipping": 
c.use_clipping,
+        "clipping_value": c.clipping,
+    },
+    # CRF
+    "crf_params": {
+        "sa": c.sa,
+        "sb": c.sb,
+        "sg": c.sg,
+        "w1": c.w1,
+        "w2": c.w2,
+        "n_iter": c.n_iter,
+    },
+}
+
+
+def train(weights_path=None, train_config=None):
+    config = train_config if train_config is not None else Config()
+    ##############
+    # disable metadata tracking
+    set_track_meta(False)
+    ##############
+    if WANDB_INSTALLED:
+        wandb.init(
+            config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE
+        )
+
+    set_determinism(seed=34936339)  # use default seed from NP_MAX
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
+    normalize_function = config.normalizing_function
+    CUDA = torch.cuda.is_available()
+    device = torch.device("cuda" if CUDA else "cpu")
+
+    print(f"Using device: {device}")
+
+    print("Config:")
+    [print(a) for a in config.__dict__.items()]
+
+    print("Initializing training...")
+    print("Getting the data")
+
+    if config.use_patch:
+        (data_shape, dataset) = get_patch_dataset(config)
+    else:
+        (data_shape, dataset) = get_dataset(config)
+        transform = Compose(
+            [
+                ToTensor(),
+                EnsureChannelFirst(channel_dim=0),
+            ]
+        )
+        dataset = [transform(im) for im in dataset]
+        for data in dataset:
+            print(f"data shape: {data.shape}")
+            break
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=config.num_workers,
+        collate_fn=pad_list_data_collate,
+    )
+
+    if config.eval_volume_directory is not None:
+        eval_dataset = get_patch_eval_dataset(config)
+
+        eval_dataloader = DataLoader(
+            eval_dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+            num_workers=config.num_workers,
+            collate_fn=pad_list_data_collate,
+        )
+
+        dice_metric = DiceMetric(
+            include_background=False, reduction="mean", get_not_nans=False
+        )
+    ###################################################
+    #               Training the model                #
+    ###################################################
+    print("Initializing the model:")
+
+    print("- getting the model")
+    # Initialize the model
+    model = WNet(
+        in_channels=config.in_channels,
+        out_channels=config.out_channels,
+        num_classes=config.num_classes,
+        dropout=config.dropout,
+    )
+    model = (
+        nn.DataParallel(model).cuda() if CUDA and config.parallel else model
+    )
+    model.to(device)
+
+    if config.use_clipping:
+        for p in model.parameters():
+            p.register_hook(
+                lambda grad: torch.clamp(
+                    grad, min=-config.clipping, max=config.clipping
+                )
+            )
+
+    if WANDB_INSTALLED:
+        wandb.watch(model, log_freq=100)
+
+    if weights_path is not None:
+        model.load_state_dict(torch.load(weights_path, map_location=device))
+
+    print("- getting the optimizers")
+    # Initialize the optimizers
+    if config.weight_decay is not None:
+        decay = config.weight_decay
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=config.lr, weight_decay=decay
+        )
+    else:
+        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
+
+    print("- getting the loss functions")
+    # Initialize the Ncuts loss function
+    criterionE = SoftNCutsLoss(
+        data_shape=data_shape,
+        device=device,
+        intensity_sigma=config.intensity_sigma,
+        spatial_sigma=config.spatial_sigma,
+        radius=config.radius,
+    )
+
+    if config.reconstruction_loss == "MSE":
+        criterionW = nn.MSELoss()
+    elif config.reconstruction_loss == "BCE":
+        criterionW = nn.BCELoss()
+    else:
+        raise ValueError(
+            f"Unknown reconstruction loss : {config.reconstruction_loss} not supported"
+        )
+
+    print("- getting the learning rate schedulers")
+    # Initialize the learning rate schedulers
+    scheduler = get_scheduler(config,
optimizer) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + # optimizer, mode="min", factor=0.5, patience=10, verbose=True + # ) + model.train() + + print("Ready") + print("Training the model") + print("*" * 50) + + startTime = time.time() + ncuts_losses = [] + rec_losses = [] + total_losses = [] + best_dice = -1 + best_dice_epoch = -1 + + # Train the model + for epoch in range(config.num_epochs): + print(f"Epoch {epoch + 1} of {config.num_epochs}") + + epoch_ncuts_loss = 0 + epoch_rec_loss = 0 + epoch_loss = 0 + + for _i, batch in enumerate(dataloader): + # raise NotImplementedError("testing") + if config.use_patch: + image = batch["image"].to(device) + else: + image = batch.to(device) + if config.batch_size == 1: + image = image.unsqueeze(0) + else: + image = image.unsqueeze(0) + image = torch.swapaxes(image, 0, 1) + + # Forward pass + enc = model.forward_encoder(image) + # out = model.forward(image) + + # Compute the Ncuts loss + Ncuts = criterionE(enc, image) + epoch_ncuts_loss += Ncuts.item() + if WANDB_INSTALLED: + wandb.log({"Ncuts loss": Ncuts.item()}) + + # Forward pass + enc, dec = model(image) + + # Compute the reconstruction loss + if isinstance(criterionW, nn.MSELoss): + reconstruction_loss = criterionW(dec, image) + elif isinstance(criterionW, nn.BCELoss): + reconstruction_loss = criterionW( + torch.sigmoid(dec), + remap_image(image, new_max=1), + ) + + epoch_rec_loss += reconstruction_loss.item() + if WANDB_INSTALLED: + wandb.log({"Reconstruction loss": reconstruction_loss.item()}) + + # Backward pass for the reconstruction loss + optimizer.zero_grad() + alpha = config.n_cuts_weight + beta = config.rec_loss_weight + + loss = alpha * Ncuts + beta * reconstruction_loss + epoch_loss += loss.item() + if WANDB_INSTALLED: + wandb.log({"Sum of losses": loss.item()}) + loss.backward(loss) + optimizer.step() + + if config.scheduler == "CosineAnnealingWarmRestarts": + scheduler.step(epoch + _i / len(dataloader)) + if ( + config.scheduler == "CosineAnnealingLR" + or config.scheduler == "CyclicLR" + ): + scheduler.step() + + ncuts_losses.append(epoch_ncuts_loss / len(dataloader)) + rec_losses.append(epoch_rec_loss / len(dataloader)) + total_losses.append(epoch_loss / len(dataloader)) + + if WANDB_INSTALLED: + wandb.log({"Ncuts loss_epoch": ncuts_losses[-1]}) + wandb.log({"Reconstruction loss_epoch": rec_losses[-1]}) + wandb.log({"Sum of losses_epoch": total_losses[-1]}) + # wandb.log({"epoch": epoch}) + # wandb.log({"learning_rate model": optimizerW.param_groups[0]["lr"]}) + # wandb.log({"learning_rate encoder": optimizerE.param_groups[0]["lr"]}) + wandb.log({"learning_rate model": optimizer.param_groups[0]["lr"]}) + + print("Ncuts loss: ", ncuts_losses[-1]) + if epoch > 0: + print( + "Ncuts loss difference: ", + ncuts_losses[-1] - ncuts_losses[-2], + ) + print("Reconstruction loss: ", rec_losses[-1]) + if epoch > 0: + print( + "Reconstruction loss difference: ", + rec_losses[-1] - rec_losses[-2], + ) + print("Sum of losses: ", total_losses[-1]) + if epoch > 0: + print( + "Sum of losses difference: ", + total_losses[-1] - total_losses[-2], + ) + + # Update the learning rate + if config.scheduler == "ReduceLROnPlateau": + # schedulerE.step(epoch_ncuts_loss) + # schedulerW.step(epoch_rec_loss) + scheduler.step(epoch_rec_loss) + if ( + config.eval_volume_directory is not None + and (epoch + 1) % config.val_interval == 0 + ): + model.eval() + print("Validating...") + with torch.no_grad(): + for _k, val_data in enumerate(eval_dataloader): + val_inputs, val_labels = ( + 
val_data["image"].to(device), + val_data["label"].to(device), + ) + + # normalize val_inputs across channels + if config.normalize_input: + for i in range(val_inputs.shape[0]): + for j in range(val_inputs.shape[1]): + val_inputs[i][j] = normalize_function( + val_inputs[i][j] + ) + + val_outputs = model.forward_encoder(val_inputs) + val_outputs = AsDiscrete(threshold=0.5)(val_outputs) + + # compute metric for current iteration + for channel in range(val_outputs.shape[1]): + max_dice_channel = torch.argmax( + torch.Tensor( + [ + dice_coeff( + y_pred=val_outputs[ + :, + channel : (channel + 1), + :, + :, + :, + ], + y_true=val_labels, + ) + ] + ) + ) + + dice_metric( + y_pred=val_outputs[ + :, + max_dice_channel : (max_dice_channel + 1), + :, + :, + :, + ], + y=val_labels, + ) + # if plot_val_input: # only once + # logged_image = val_inputs.detach().cpu().numpy() + # logged_image = np.swapaxes(logged_image, 2, 4) + # logged_image = logged_image[0, :, 32, :, :] + # images = wandb.Image( + # logged_image, caption="Validation input" + # ) + # + # wandb.log({"val/input": images}) + # plot_val_input = False + + # if k == 2 and (30 <= epoch <= 50 or epoch % 100 == 0): + # logged_image = val_outputs.detach().cpu().numpy() + # logged_image = np.swapaxes(logged_image, 2, 4) + # logged_image = logged_image[ + # 0, max_dice_channel, 32, :, : + # ] + # images = wandb.Image( + # logged_image, caption="Validation output" + # ) + # + # wandb.log({"val/output": images}) + # dice_metric(y_pred=val_outputs[:, 2:, :,:,:], y=val_labels) + # dice_metric(y_pred=val_outputs[:, 1:, :, :, :], y=val_labels) + + # import napari + # view = napari.Viewer() + # view.add_image(val_inputs.cpu().numpy(), name="input") + # view.add_image(val_labels.cpu().numpy(), name="label") + # vis_out = np.array( + # [i.detach().cpu().numpy() for i in val_outputs], + # dtype=np.float32, + # ) + # crf_out = np.array( + # [i.detach().cpu().numpy() for i in crf_outputs], + # dtype=np.float32, + # ) + # view.add_image(vis_out, name="output") + # view.add_image(crf_out, name="crf_output") + # napari.run() + + # aggregate the final mean dice result + metric = dice_metric.aggregate().item() + print("Validation Dice score: ", metric) + if best_dice < metric < 2: + best_dice = metric + best_dice_epoch = epoch + 1 + if config.save_model: + save_best_path = Path(config.save_model_path).parents[ + 0 + ] + save_best_path.mkdir(parents=True, exist_ok=True) + save_best_name = Path(config.save_model_path).stem + save_path = ( + str(save_best_path / save_best_name) + + "_best_metric.pth" + ) + print(f"Saving new best model to {save_path}") + torch.save(model.state_dict(), save_path) + + if WANDB_INSTALLED: + # log validation dice score for each validation round + wandb.log({"val/dice_metric": metric}) + + # reset the status for next validation round + dice_metric.reset() + + print( + "ETA: ", + (time.time() - startTime) + * (config.num_epochs / (epoch + 1) - 1) + / 60, + "minutes", + ) + print("-" * 20) + + # Save the model + if config.save_model and epoch % config.save_every == 0: + torch.save(model.state_dict(), config.save_model_path) + # with open(config.save_losses_path, "wb") as f: + # pickle.dump((ncuts_losses, rec_losses), f) + + print("Training finished") + print(f"Best dice metric : {best_dice}") + if WANDB_INSTALLED and config.eval_volume_directory is not None: + wandb.log( + { + "best_dice_metric": best_dice, + "best_metric_epoch": best_dice_epoch, + } + ) + print("*" * 50) + + # Save the model + if config.save_model: + print("Saving the model 
to: ", config.save_model_path) + torch.save(model.state_dict(), config.save_model_path) + # with open(config.save_losses_path, "wb") as f: + # pickle.dump((ncuts_losses, rec_losses), f) + if WANDB_INSTALLED: + model_artifact = wandb.Artifact( + "WNet", + type="model", + description="WNet benchmark", + metadata=dict(WANDB_CONFIG), + ) + model_artifact.add_file(config.save_model_path) + wandb.log_artifact(model_artifact) + + return ncuts_losses, rec_losses, model + + +def get_dataset(config): + """Creates a Dataset from the original data using the tifffile library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + train_files = [d.get("image") for d in train_files] + volumes = tiff.imread(train_files).astype(np.float32) + volume_shape = volumes.shape + + if config.normalize_input: + volumes = np.array( + [ + # mad_normalization(volume) + config.normalizing_function(volume) + for volume in volumes + ] + ) + # mean = volumes.mean(axis=0) + # std = volumes.std(axis=0) + # volumes = (volumes - mean) / std + # print("NORMALIZED VOLUMES") + # print(volumes.shape) + # [print("MIN MAX", volume.flatten().min(), volume.flatten().max()) for volume in volumes] + # print(volumes.mean(axis=0), volumes.std(axis=0)) + + dataset = CacheDataset(data=volumes) + + return (volume_shape, dataset) + + # train_files = create_dataset_dict_no_labs( + # volume_directory=config.train_volume_directory + # ) + # train_files = [d.get("image") for d in train_files] + # volumes = [] + # for file in train_files: + # image = tiff.imread(file).astype(np.float32) + # image = np.expand_dims(image, axis=0) # add channel dimension + # volumes.append(image) + # # volumes = tiff.imread(train_files).astype(np.float32) + # volume_shape = volumes[0].shape + # # print(volume_shape) + # + # if config.do_augmentation: + # augmentation = Compose( + # [ + # ScaleIntensityRange( + # a_min=0, + # a_max=2000, + # b_min=0.0, + # b_max=1.0, + # clip=True, + # ), + # RandShiftIntensity(offsets=0.1, prob=0.5), + # RandFlip(spatial_axis=[1], prob=0.5), + # RandFlip(spatial_axis=[2], prob=0.5), + # RandRotate90(prob=0.1, max_k=3), + # ] + # ) + # else: + # augmentation = None + # + # dataset = CacheDataset(data=np.array(volumes), transform=augmentation) + # + # return (volume_shape, dataset) + + +def get_patch_dataset(config): + """Creates a Dataset from the original data using the tifffile library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + + patch_func = Compose( + [ + LoadImaged(keys=["image"], image_only=True), + EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"), + RandSpatialCropSamplesd( + keys=["image"], + roi_size=( + config.patch_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=config.num_patches, + ), + Orientationd(keys=["image"], axcodes="PLI"), + SpatialPadd( + keys=["image"], + spatial_size=(get_padding_dim(config.patch_size)), + ), + EnsureTyped(keys=["image"]), + ] + ) + + train_transforms = Compose( + [ + ScaleIntensityRanged( + keys=["image"], + a_min=0, + a_max=2000, + b_min=0.0, + b_max=1.0, + clip=True, + ), + RandShiftIntensityd(keys=["image"], offsets=0.1, 
prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[1], prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[2], prob=0.5), + RandRotate90d(keys=["image"], prob=0.1, max_k=3), + EnsureTyped(keys=["image"]), + ] + ) + + dataset = PatchDataset( + data=train_files, + samples_per_image=config.num_patches, + patch_func=patch_func, + transform=train_transforms, + ) + + return config.patch_size, dataset + + +def get_patch_eval_dataset(config): + eval_files = create_dataset_dict( + volume_directory=config.eval_volume_directory + "/vol", + label_directory=config.eval_volume_directory + "/lab", + ) + + patch_func = Compose( + [ + LoadImaged(keys=["image", "label"], image_only=True), + EnsureChannelFirstd( + keys=["image", "label"], channel_dim="no_channel" + ), + # NormalizeIntensityd(keys=["image"]) if config.normalize_input else lambda x: x, + RandSpatialCropSamplesd( + keys=["image", "label"], + roi_size=( + config.patch_size + ), # multiply by axis_stretch_factor if anisotropy + # max_roi_size=(120, 120, 120), + random_size=False, + num_samples=config.eval_num_patches, + ), + Orientationd(keys=["image", "label"], axcodes="PLI"), + SpatialPadd( + keys=["image", "label"], + spatial_size=(get_padding_dim(config.patch_size)), + ), + EnsureTyped(keys=["image", "label"]), + ] + ) + + eval_transforms = Compose( + [ + EnsureTyped(keys=["image", "label"]), + ] + ) + + return PatchDataset( + data=eval_files, + samples_per_image=config.eval_num_patches, + patch_func=patch_func, + transform=eval_transforms, + ) + + +def get_dataset_monai(config): + """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library + + Args: + config (Config): The configuration object + + Returns: + (tuple): A tuple containing the shape of the data and the dataset + """ + train_files = create_dataset_dict_no_labs( + volume_directory=config.train_volume_directory + ) + # print(train_files) + # print(len(train_files)) + # print(train_files[0]) + first_volume = LoadImaged(keys=["image"])(train_files[0]) + first_volume_shape = first_volume["image"].shape + + # Transforms to be applied to each volume + load_single_images = Compose( + [ + LoadImaged(keys=["image"]), + EnsureChannelFirstd(keys=["image"]), + Orientationd(keys=["image"], axcodes="PLI"), + SpatialPadd( + keys=["image"], + spatial_size=(get_padding_dim(first_volume_shape)), + ), + EnsureTyped(keys=["image"]), + ] + ) + + if config.do_augmentation: + train_transforms = Compose( + [ + ScaleIntensityRanged( + keys=["image"], + a_min=0, + a_max=2000, + b_min=0.0, + b_max=1.0, + clip=True, + ), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[1], prob=0.5), + RandFlipd(keys=["image"], spatial_axis=[2], prob=0.5), + RandRotate90d(keys=["image"], prob=0.1, max_k=3), + EnsureTyped(keys=["image"]), + ] + ) + else: + train_transforms = EnsureTyped(keys=["image"]) + + # Create the dataset + dataset = CacheDataset( + data=train_files, + transform=Compose(load_single_images, train_transforms), + ) + + return first_volume_shape, dataset + + +def get_scheduler(config, optimizer, verbose=False): + scheduler_name = config.scheduler + if scheduler_name == "None": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=100, + eta_min=config.lr - 1e-6, + verbose=verbose, + ) + + elif scheduler_name == "ReduceLROnPlateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=schedulers["ReduceLROnPlateau"]["factor"], + 
patience=schedulers["ReduceLROnPlateau"]["patience"], + verbose=verbose, + ) + elif scheduler_name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=schedulers["CosineAnnealingLR"]["T_max"], + eta_min=schedulers["CosineAnnealingLR"]["eta_min"], + verbose=verbose, + ) + elif scheduler_name == "CosineAnnealingWarmRestarts": + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=schedulers["CosineAnnealingWarmRestarts"]["T_0"], + eta_min=schedulers["CosineAnnealingWarmRestarts"]["eta_min"], + T_mult=schedulers["CosineAnnealingWarmRestarts"]["T_mult"], + verbose=verbose, + ) + elif scheduler_name == "CyclicLR": + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=schedulers["CyclicLR"]["base_lr"], + max_lr=schedulers["CyclicLR"]["max_lr"], + step_size_up=schedulers["CyclicLR"]["step_size_up"], + mode=schedulers["CyclicLR"]["mode"], + cycle_momentum=False, + ) + else: + raise ValueError(f"Scheduler {scheduler_name} not provided") + return scheduler + + +if __name__ == "__main__": + weights_location = str( + # Path(__file__).resolve().parent / "../weights/wnet.pth" + # "../wnet_SUM_MSE_DAPI_rad2_best_metric.pth" + ) + train( + # weights_location + ) diff --git a/napari_cellseg3d/code_models/model_workers.py b/napari_cellseg3d/code_models/workers.py similarity index 72% rename from napari_cellseg3d/code_models/model_workers.py rename to napari_cellseg3d/code_models/workers.py index 30d37bbd..50f85395 100644 --- a/napari_cellseg3d/code_models/model_workers.py +++ b/napari_cellseg3d/code_models/workers.py @@ -1,8 +1,9 @@ import platform +import time +import typing as t from dataclasses import dataclass from math import ceil from pathlib import Path -from typing import List, Optional import numpy as np import torch @@ -40,6 +41,7 @@ ) from monai.utils import set_determinism +# from napari.qt.threading import thread_worker # threads from napari.qt.threading import GeneratorWorker, WorkerBaseSignals @@ -51,31 +53,25 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.crf import crf_with_config +from napari_cellseg3d.code_models.instance_segmentation import ( ImageStats, volume_stats, ) logger = utils.LOGGER -""" -Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) -so instead, following the instructions in the guides below to have a worker with custom signals, I implemented -a custom worker function.""" - -# FutureReference(): -# https://python-forum.io/thread-31349.html -# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ -# https://napari-staging-site.github.io/guides/stable/threading.html - -WEIGHTS_DIR = Path(__file__).parent.resolve() / Path("models/pretrained") -logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {WEIGHTS_DIR}") +PRETRAINED_WEIGHTS_DIR = Path(__file__).parent.resolve() / Path( + "models/pretrained" +) +VERBOSE_SCHEDULER = True +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") class WeightsDownloader: """A utility class the downloads the weights of a model when needed.""" - def __init__(self, log_widget: Optional[ui.Log] = None): + def __init__(self, log_widget: t.Optional[ui.Log] = None): """ Creates a WeightsDownloader, optionally with a log widget to display the progress. 
@@ -97,11 +93,11 @@ def download_weights(self, model_name: str, model_weights_filename: str):
         import tarfile
         import urllib.request

-        def show_progress(count, block_size, total_size):
+        def show_progress(_, block_size, __):  # count, block_size, total_size
             pbar.update(block_size)

         logger.info("*" * 20)
-        pretrained_folder_path = WEIGHTS_DIR
+        pretrained_folder_path = PRETRAINED_WEIGHTS_DIR
         json_path = pretrained_folder_path / Path("pretrained_model_urls.json")

         check_path = pretrained_folder_path / Path(model_weights_filename)
@@ -113,7 +109,7 @@ def show_progress(count, block_size, total_size):
             logger.info(message)
             return

-        with open(json_path) as f:
+        with Path.open(json_path) as f:
             neturls = json.load(f)
         if model_name in neturls:
             url = neturls[model_name]
@@ -168,15 +164,31 @@ def safe_extract(
     )


+"""
+Writing log messages from outside the main thread needs specific care;
+following the instructions in the guides below, a custom worker with its
+own signals was implemented.
+"""
+
+# https://python-forum.io/thread-31349.html
+# https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/
+# https://napari-staging-site.github.io/guides/stable/threading.html
+
+
 class LogSignal(WorkerBaseSignals):
     """Signal to send messages to be logged from another thread.

-    Separate from Worker instances as indicated `here`_"""  # TODO link ?
+    Separate from Worker instances as indicated `on this post`_
+
+    .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
+    """

     log_signal = Signal(str)
     """qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
     warn_signal = Signal(str)
     """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in the main thread"""
+    error_signal = Signal(Exception, str)
+    """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in the main thread"""

     # Should not be an instance variable but a class variable, not defined in __init__, see
     # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
@@ -185,6 +197,44 @@ def __init__(self):
         super().__init__()


+# TODO(cyril): move inference and training workers to separate files
+
+
+class ONNXModelWrapper(torch.nn.Module):
+    """Class to replace a torch model by an ONNX Runtime session"""
+
+    def __init__(self, file_location):
+        super().__init__()
+        try:
+            import onnxruntime as ort
+        except ImportError as e:
+            logger.error("ONNX is not installed but ONNX model was loaded")
+            logger.error(e)
+            msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]"
+            logger.error(msg)
+            raise ImportError(msg) from e
+
+        self.ort_session = ort.InferenceSession(
+            file_location,
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+        )
+
+    def forward(self, model_input):
+        """Wraps ONNX output in a torch tensor"""
+        outputs = self.ort_session.run(
+            None, {"input": model_input.cpu().numpy()}
+        )
+        return torch.tensor(outputs[0])
+
+    def eval(self):
+        """Dummy function to replace model.eval()"""
+        pass
+
+    def to(self, device):
+        """Dummy function to replace model.to(device)"""
+        pass
+
+
 @dataclass
 class InferenceResult:
     """Class to record results of a segmentation job"""
@@ -192,7 +242,8 @@ class InferenceResult:
     image_id: int = 0
     original: np.array = None
     instance_labels: np.array = None
-    stats: ImageStats = None
+    crf_results: np.array = None
+    stats: "np.array[ImageStats]" = None
     result: np.array = None
     model_name: str = None
@@ -207,33 +258,24 @@ def __init__(
     ):
         """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function.

-        Args:
-            * config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements
-            * device: cuda or cpu device to use for torch
-
-            * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance
-
-            * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom
-
-            * results_path: the path to save the results to
-
-            * filetype: the file extension to use when saving,
-
-            * transforms: a dict containing transforms to perform at various times.
-
-            * instance: a dict containing parameters regarding instance segmentation
+        The config contains the following attributes:
+            * device: cuda or cpu device to use for torch
+            * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance
+            * weights_dict: dict with "custom" : bool to use custom weights or not; "path" : the path to weights if custom or name of the file if not custom
+            * results_path: the path to save the results to
+            * filetype: the file extension to use when saving
+            * transforms: a dict containing transforms to perform at various times.
+            * instance: a dict containing parameters regarding instance segmentation
+            * use_window: use window inference with specific size or whole image
+            * window_infer_size: size of window if use_window is True
+            * keep_on_cpu: keep images on CPU or not
+            * stats_csv: compute stats on cells and save them to a csv file
+            * images_filepaths: the paths to the images of the dataset
+            * layer: the layer to run inference on

-            * use_window: use window inference with specific size or whole image
-
-            * window_infer_size: size of window if use_window is True
-
-            * keep_on_cpu: keep images on CPU or no
-
-            * stats_csv: compute stats on cells and save them to a csv file
-
-            * images_filepaths: the paths to the images of the dataset
+        Args:
+            * worker_config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements

-            * layer: the layer to run inference on

         Note: See :py:func:`~self.inference`
         """
@@ -241,6 +283,7 @@ def __init__(
         self._signals = LogSignal()  # add custom signals
         self.log_signal = self._signals.log_signal
         self.warn_signal = self._signals.warn_signal
+        self.error_signal = self._signals.error_signal

         self.config = worker_config
@@ -255,8 +298,7 @@ def create_inference_dict(images_filepaths):
         Returns:
            list: dicts of image paths from the loaded folder"""
-        data_dicts = [{"image": image_name} for image_name in images_filepaths]
-        return data_dicts
+        return [{"image": image_name} for image_name in images_filepaths]

     def set_download_log(self, widget):
         self.downloader.log_widget = widget
@@ -273,6 +315,21 @@ def warn(self, warning):
         """Sends a warning to the main thread"""
         self.warn_signal.emit(warning)

+    def raise_error(self, exception, msg):
+        """Raises an error in the main thread"""
+        logger.error(msg, exc_info=True)
+        logger.error(exception, exc_info=True)
+
+        self.log_signal.emit("!" * 20)
+        self.log_signal.emit("Error occurred")
+        # self.log_signal.emit(msg)
+        # self.log_signal.emit(str(exception))
+
+        self.error_signal.emit(exception, msg)
+        self.errored.emit(exception)
+        # note: no `yield` here, otherwise this method would become a
+        # generator and none of the emits above would run when it is called
+        # self.quit()
+
     def log_parameters(self):

         config = self.config
@@ -402,7 +459,7 @@ def load_layer(self):
         )  # for anisotropy to be monai-like, i.e. zyx
         # FIXME rotation not always correct
         dims_check = volume.shape
-        # self.log("\nChecking dimensions...")
+        self.log("Checking dimensions...")
         pad = utils.get_padding_dim(dims_check)
         # logger.debug(volume.shape)
@@ -453,58 +510,67 @@ def model_output(
         #     self.config.model_info.get_model().get_output(model, inputs)
         # )

-        def model_output(inputs):
-            return post_process_transforms(
-                self.config.model_info.get_model().get_output(model, inputs)
-            )

         dataset_device = (
             "cpu" if self.config.keep_on_cpu else self.config.device
         )

-        window_size = self.config.sliding_window_config.window_size
-        window_overlap = self.config.sliding_window_config.window_overlap
-
-        # FIXME
-        # import sys
-
-        # old_stdout = sys.stdout
-        # old_stderr = sys.stderr
-
-        # sys.stdout = self.downloader.log_widget
-        # sys.stdout = self.downloader.log_widget
-
-        outputs = sliding_window_inference(
-            inputs,
-            roi_size=[window_size, window_size, window_size],
-            sw_batch_size=1,  # TODO add param
-            predictor=model_output,
-            sw_device=self.config.device,
-            device=dataset_device,
-            overlap=window_overlap,
-            progress=True,
-        )
-
-        # sys.stdout = old_stdout
-        # sys.stderr = old_stderr
-
-        out = outputs.detach().cpu()
-
-        if aniso_transform is not None:
-            out = aniso_transform(out)
-
-        if post_process:
-            out = np.array(out).astype(np.float32)
-            out = np.squeeze(out)
-            return out
+        if self.config.sliding_window_config.is_enabled():
+            window_size = self.config.sliding_window_config.window_size
+            window_size = [window_size, window_size, window_size]
+            window_overlap = self.config.sliding_window_config.window_overlap
         else:
+            window_size = None
+            window_overlap = 0
+        try:
+            # logger.debug(f"model : {model}")
+            logger.debug(f"inputs shape : {inputs.shape}")
+            logger.debug(f"inputs type : {inputs.dtype}")
+            try:
+                # outputs = model(inputs)
+                inputs = utils.remap_image(inputs)
+
+                def model_output_wrapper(inputs):
+                    result = model(inputs)
+                    return post_process_transforms(result)
+
+                with torch.no_grad():
+                    outputs = sliding_window_inference(
+                        inputs,
+                        roi_size=window_size,
+                        sw_batch_size=1,  # TODO add param
+                        predictor=model_output_wrapper,
+                        sw_device=self.config.device,
+                        device=dataset_device,
+                        overlap=window_overlap,
+                        mode="gaussian",
+                        sigma_scale=0.01,
+                        progress=True,
+                    )
+            except Exception as e:
+                logger.exception(e)
+                logger.debug("failed to run sliding window inference")
+                self.raise_error(e, "Error during sliding window inference")
+            logger.debug(f"Inference output shape: {outputs.shape}")
+            self.log("Post-processing...")
+            out = outputs.detach().cpu().numpy()
+            if aniso_transform is not None:
+                out = aniso_transform(out)
+            if post_process:
+                out = np.array(out).astype(np.float32)
+                out = np.squeeze(out)
             return out
+        except Exception as e:
+            logger.exception(e)
+            self.raise_error(e, "Error during sliding window inference")
         # sys.stdout = old_stdout
         # sys.stderr = old_stderr

-    def create_result_dict(  # FIXME replace with result class
+    def create_inference_result(
         self,
         semantic_labels,
         instance_labels,
-        from_layer: bool,
+        crf_results=None,
+        from_layer: bool = False,
         original=None,
         stats=None,
         i=0,
@@ -519,12 +585,21 @@ def create_result_dict( # FIXME replace with
result class raise ValueError( "A layer's ID should always be 0 (default value)" ) - semantic_labels = np.swapaxes(semantic_labels, 0, 2) + + if semantic_labels is not None: + semantic_labels = utils.correct_rotation(semantic_labels) + if crf_results is not None: + crf_results = utils.correct_rotation(crf_results) + if instance_labels is not None: + instance_labels = utils.correct_rotation( + instance_labels + ) # TODO(cyril) check if correct return InferenceResult( image_id=i + 1, original=original, instance_labels=instance_labels, + crf_results=crf_results, stats=stats, result=semantic_labels, model_name=self.config.model_info.name, @@ -544,10 +619,6 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): semantic_labels, i + 1, ) - if from_layer: - instance_labels = np.swapaxes( - instance_labels, 0, 2 - ) # TODO(cyril) check if correct data_dict = self.stats_csv(instance_labels) else: instance_labels = None @@ -559,24 +630,31 @@ def save_image( image, from_layer=False, i=0, + additional_info="", ): if not from_layer: original_filename = "_" + self.get_original_filename(i) + "_" + filetype = self.config.filetype else: original_filename = "_" + filetype = ".tif" time = utils.get_date_time() file_path = ( self.config.results_path + "/" + + f"{additional_info}" + f"Prediction_{i+1}" + original_filename + self.config.model_info.name - + f"_{time}_" - + self.config.filetype + + f"_{time}" + + filetype ) - imwrite(file_path, image) + try: + imwrite(file_path, image) + except ValueError as e: + self.raise_error(e, "Error during image saving") filename = Path(file_path).stem if from_layer: @@ -593,15 +671,23 @@ def aniso_transform(self, image): padding_mode="empty", ) return anisotropic_transform(image[0]) - else: - return image + return image - def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + def instance_seg( + self, semantic_labels, image_id=0, original_filename="layer" + ): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") method = self.config.post_process_config.instance.method - instance_labels = method.run_method(image=to_instance) + instance_labels = method.run_method_on_channels(semantic_labels) + self.log(f"DEBUG instance results shape : {instance_labels.shape}") + + filetype = ( + ".tif" + if self.config.filetype == "" + else "_" + self.config.filetype + ) instance_filepath = ( self.config.results_path @@ -610,8 +696,8 @@ def instance_seg(self, to_instance, image_id=0, original_filename="layer"): + original_filename + "_" + self.config.model_info.name - + f"_{utils.get_date_time()}_" - + self.config.filetype + + f"_{utils.get_date_time()}" + + filetype ) imwrite(instance_filepath, instance_labels) @@ -636,31 +722,65 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): self.save_image(out, i=i) instance_labels, stats = self.get_instance_result(out, i=i) + if self.config.use_crf: + try: + crf_results = self.run_crf( + inputs, + out, + aniso_transform=self.aniso_transform, + image_id=i, + ) + + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + crf_results = None + else: + crf_results = None original = np.array(inf_data["image"]).astype(np.float32) self.log(f"Inference completed on image n°{i+1}") - return self.create_result_dict( + return self.create_inference_result( out, instance_labels, + crf_results, from_layer=False, original=original, stats=stats, i=i, ) + def run_crf(self, image, labels, aniso_transform, image_id=0): + try: + if 
aniso_transform is not None: + image = aniso_transform(image) + crf_results = crf_with_config( + image, labels, config=self.config.crf_config, log=self.log + ) + self.save_image( + crf_results, + i=image_id, + additional_info="CRF_", + from_layer=True, + ) + return crf_results + except ValueError as e: + self.log(f"Error occurred during CRF : {e}") + return None + def stats_csv(self, instance_labels): - if self.config.compute_stats: - stats = volume_stats( - instance_labels - ) # TODO test with area mesh function + try: + if self.config.compute_stats: + if len(instance_labels.shape) == 4: + stats = [volume_stats(c) for c in instance_labels] + else: + stats = [volume_stats(instance_labels)] + else: + stats = None return stats - - # except ValueError as e: - # self.log(f"Error occurred during stats computing : {e}") - # return None - else: + except ValueError as e: + self.log(f"Error occurred during stats computing : {e}") return None def inference_on_layer(self, image, model, post_process_transforms): @@ -678,15 +798,25 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, stats = self.get_instance_result(out, from_layer=True) + instance_labels, stats = self.get_instance_result( + semantic_labels=out, from_layer=True + ) + + crf_results = ( + self.run_crf(image, out, aniso_transform=self.aniso_transform) + if self.config.use_crf + else None + ) - return self.create_result_dict( + return self.create_inference_result( semantic_labels=out, instance_labels=instance_labels, + crf_results=crf_results, from_layer=True, stats=stats, ) + # @thread_worker(connect={"errored": self.raise_error}) def inference(self): """ Requires: @@ -729,35 +859,77 @@ def inference(self): try: dims = self.config.model_info.model_input_size - # self.log(f"MODEL DIMS : {dims}") + self.log(f"MODEL DIMS : {dims}") model_name = self.config.model_info.name model_class = self.config.model_info.get_model() - self.log(model_name) + self.log(f"Model name : {model_name}") weights_config = self.config.weights_config post_process_config = self.config.post_process_config - - if model_name == "SegResNet": - model = model_class.get_net( - input_image_size=[ - dims, - dims, - dims, - ], # TODO FIX ! 
find a better way & remove model-specific code + if Path(weights_config.path).suffix == ".pt": + self.log("Instantiating PyTorch jit model...") + model = torch.jit.load(weights_config.path) + # try: + elif Path(weights_config.path).suffix == ".onnx": + self.log("Instantiating ONNX model...") + model = ONNXModelWrapper(weights_config.path) + else: # assume is .pth + self.log("Instantiating model...") + model = model_class( # FIXME test if works + input_img_size=[dims, dims, dims], + device=self.config.device, + num_classes=self.config.model_info.num_classes, ) - elif model_name == "SwinUNetR": - model = model_class.get_net( - img_size=[dims, dims, dims], - use_checkpoint=False, + # try: + model = model.to(self.config.device) + # except Exception as e: + # self.raise_error(e, "Issue loading model to device") + # logger.debug(f"model : {model}") + if model is None: + raise ValueError("Model is None") + # try: + self.log("\nLoading weights...") + if weights_config.custom: + weights = weights_config.path + else: + self.downloader.download_weights( + model_name, + model_class.weights_file, + ) + weights = str( + PRETRAINED_WEIGHTS_DIR / Path(model_class.weights_file) + ) + + model.load_state_dict( # note that this is redefined in WNet_ + torch.load( + weights, + map_location=self.config.device, + ) ) - else: - model = model_class.get_net() - model = model.to(self.config.device) + self.log("Done") + # except Exception as e: + # self.raise_error(e, "Issue loading weights") + # except Exception as e: + # self.raise_error(e, "Issue instantiating model") + + # if model_name == "SegResNet": + # model = model_class( + # input_image_size=[ + # dims, + # dims, + # dims, + # ], + # ) + # elif model_name == "SwinUNetR": + # model = model_class( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = model_class.get_net() self.log_parameters() - model.to(self.config.device) - # load_transforms = Compose( # [ # LoadImaged(keys=["image"]), @@ -778,25 +950,6 @@ def inference(self): AsDiscrete(threshold=t), EnsureType() ) - self.log("\nLoading weights...") - if weights_config.custom: - weights = weights_config.path - else: - self.downloader.download_weights( - model_name, - model_class.get_weights_file(), - ) - weights = str( - WEIGHTS_DIR / Path(model_class.get_weights_file()) - ) - model.load_state_dict( - torch.load( - weights, - map_location=self.config.device, - ) - ) - self.log("Done") - is_folder = self.config.images_filepaths is not None is_layer = self.config.layer is not None @@ -821,6 +974,9 @@ def inference(self): else: raise ValueError("No data has been provided. Aborting.") + if model is None: + raise ValueError("Model is None") + model.eval() with torch.no_grad(): ################################ @@ -836,9 +992,10 @@ def inference(self): input_image, model, post_process_transforms ) model.to("cpu") - + # self.quit() except Exception as e: - self.log(f"Error during inference : {e}") + logger.exception(e) + self.raise_error(e, "Inference failed") self.quit() finally: self.quit() @@ -848,10 +1005,10 @@ def inference(self): class TrainingReport: show_plot: bool = True epoch: int = 0 - loss_values: List = None - validation_metric: List = None + loss_values: t.Dict = None # TODO(cyril) : change to dict and unpack different losses for e.g. 
WNet with several losses
+    validation_metric: t.List = None
     weights: np.array = None
-    images: List[np.array] = None
+    images: t.List[np.array] = None


 class TrainingWorker(GeneratorWorker):
@@ -860,7 +1017,7 @@ class TrainingWorker(GeneratorWorker):

     def __init__(
         self,
-        config: config.TrainingWorkerConfig,
+        worker_config: config.TrainingWorkerConfig,
     ):
         """Initializes a worker for training with the arguments needed by the :py:func:`~train` function.
         Note: See :py:func:`~train`
         """
@@ -903,10 +1060,11 @@ def __init__(
         self._signals = LogSignal()
         self.log_signal = self._signals.log_signal
         self.warn_signal = self._signals.warn_signal
+        self.error_signal = self._signals.error_signal

         self._weight_error = False
         #############################################
-        self.config = config
+        self.config = worker_config

         self.train_files = []
         self.val_files = []
@@ -928,6 +1086,14 @@ def warn(self, warning):
         """Sends a warning to the main thread"""
         self.warn_signal.emit(warning)

+    def raise_error(self, exception, msg):
+        """Sends an error to the main thread"""
+        logger.error(msg, exc_info=True)
+        logger.error(exception, exc_info=True)
+        self.error_signal.emit(exception, msg)
+        self.errored.emit(exception)
+        self.quit()
+
     def log_parameters(self):
         self.log("-" * 20)
         self.log("Parameters summary :\n")
@@ -1033,6 +1199,8 @@ def train(self):
         weights_config = self.config.weights_info
         deterministic_config = self.config.deterministic_config

+        start_time = time.time()
+
         try:
             if deterministic_config.enabled:
                 set_determinism(
@@ -1057,23 +1225,11 @@ def train(self):

             do_sampling = self.config.sampling

-            if model_name == "SegResNet":
-                size = self.config.sample_size if do_sampling else check
-                logger.info(f"Size of image : {size}")
-                model = model_class.get_net(
-                    input_image_size=utils.get_padding_dim(size),
-                    # out_channels=1,
-                    # dropout_prob=0.3,
-                )
-            elif model_name == "SwinUNetR":
-                size = self.sample_size if do_sampling else check
-                logger.info(f"Size of image : {size}")
-                model = model_class.get_net(
-                    img_size=utils.get_padding_dim(size),
-                    use_checkpoint=True,
-                )
-            else:
-                model = model_class.get_net()  # get an instance of the model
+            size = self.config.sample_size if do_sampling else check
+
+            model = model_class(  # FIXME check if correct
+                input_img_size=utils.get_padding_dim(size), use_checkpoint=True
+            )

             model = model.to(self.config.device)

             epoch_loss_values = []
@@ -1118,31 +1274,6 @@ def train(self):
             if len(self.val_files) == 0:
                 raise ValueError("Validation dataset is empty")

-            if do_sampling:
-                sample_loader = Compose(
-                    [
-                        LoadImaged(keys=["image", "label"]),
-                        EnsureChannelFirstd(keys=["image", "label"]),
-                        RandSpatialCropSamplesd(
-                            keys=["image", "label"],
-                            roi_size=(
-                                self.config.sample_size
-                            ),  # multiply by axis_stretch_factor if anisotropy
-                            # max_roi_size=(120, 120, 120),
-                            random_size=False,
-                            num_samples=self.config.num_samples,
-                        ),
-                        Orientationd(keys=["image", "label"], axcodes="PLI"),
-                        SpatialPadd(
-                            keys=["image", "label"],
-                            spatial_size=(
-                                utils.get_padding_dim(self.config.sample_size)
-                            ),
-                        ),
-                        EnsureTyped(keys=["image", "label"]),
-                    ]
-                )
-
             if self.config.do_augmentation:
                 train_transforms = (
                     Compose(  # TODO : figure out which ones and values ?
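
[Editor's note — illustrative sketch, not part of the patch] The next hunk replaces the single `sample_loader` with a `get_loader_func` factory, so the train and validation splits each get a sampler with their own `num_samples`. A minimal standalone example of how such a MONAI patch-sampling pipeline behaves, using synthetic in-memory data and hypothetical sizes (LoadImaged and Orientationd are omitted since no files or affines are involved):

    import numpy as np
    from monai.transforms import (
        Compose,
        EnsureChannelFirstd,
        EnsureTyped,
        RandSpatialCropSamplesd,
        SpatialPadd,
    )

    data = {
        "image": np.random.rand(64, 64, 64).astype(np.float32),
        "label": np.zeros((64, 64, 64), dtype=np.uint8),
    }
    sampler = Compose(
        [
            # add a channel axis so shapes become (1, z, y, x)
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            # draw several fixed-size random patches per volume
            RandSpatialCropSamplesd(
                keys=["image", "label"],
                roi_size=32,
                random_size=False,
                num_samples=4,
            ),
            # pad in case a volume is smaller than the requested patch
            SpatialPadd(keys=["image", "label"], spatial_size=32),
            EnsureTyped(keys=["image", "label"]),
        ]
    )
    patches = sampler(data)  # list of 4 dicts, each with (1, 32, 32, 32) tensors
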
@@ -1172,7 +1303,33 @@ def train(self):
                     EnsureTyped(keys=["image", "label"]),
                 ]
             )
+            # self.log("Loading dataset...\n")
+            def get_loader_func(num_samples):
+                return Compose(
+                    [
+                        LoadImaged(keys=["image", "label"]),
+                        EnsureChannelFirstd(keys=["image", "label"]),
+                        RandSpatialCropSamplesd(
+                            keys=["image", "label"],
+                            roi_size=(
+                                self.config.sample_size
+                            ),  # multiply by axis_stretch_factor if anisotropy
+                            # max_roi_size=(120, 120, 120),
+                            random_size=False,
+                            num_samples=num_samples,
+                        ),
+                        Orientationd(keys=["image", "label"], axcodes="PLI"),
+                        SpatialPadd(
+                            keys=["image", "label"],
+                            spatial_size=(
+                                utils.get_padding_dim(self.config.sample_size)
+                            ),
+                        ),
+                        EnsureTyped(keys=["image", "label"]),
+                    ]
+                )

             if do_sampling:
                 # if there is only one volume, split samples
                 # TODO(cyril) : maybe implement something in user config to toggle this behavior
@@ -1185,11 +1342,16 @@ def train(self):
                         self.config.num_samples
                         * (1 - self.config.validation_percent)
                     )
+                    sample_loader_train = get_loader_func(num_train_samples)
+                    sample_loader_eval = get_loader_func(num_val_samples)
                 else:
                     num_train_samples = (
                         num_val_samples
                     ) = self.config.num_samples
+                    sample_loader_train = get_loader_func(num_train_samples)
+                    sample_loader_eval = get_loader_func(num_val_samples)
+
                 logger.debug(f"AMOUNT of train samples : {num_train_samples}")
                 logger.debug(
                     f"AMOUNT of validation samples : {num_val_samples}"
                 )
@@ -1199,21 +1361,25 @@ def train(self):
                 train_ds = PatchDataset(
                     data=self.train_files,
                     transform=train_transforms,
-                    patch_func=sample_loader,
+                    patch_func=sample_loader_train,
                     samples_per_image=num_train_samples,
                 )
                 logger.debug("val_ds")
                 val_ds = PatchDataset(
                     data=self.val_files,
                     transform=val_transforms,
-                    patch_func=sample_loader,
+                    patch_func=sample_loader_eval,
                     samples_per_image=num_val_samples,
                 )

             else:
                 load_whole_images = Compose(
                     [
-                        LoadImaged(keys=["image", "label"]),
+                        LoadImaged(
+                            keys=["image", "label"],
+                            # image_only=True,
+                            # reader=WSIReader(backend="tifffile")
+                        ),
                         EnsureChannelFirstd(keys=["image", "label"]),
                         Orientationd(keys=["image", "label"], axcodes="PLI"),
                         SpatialPadd(
@@ -1250,7 +1416,23 @@ def train(self):
             optimizer = torch.optim.Adam(
                 model.parameters(), self.config.learning_rate
             )
-            dice_metric = DiceMetric(include_background=True, reduction="mean")
+
+            factor = self.config.scheduler_factor
+            if factor >= 1.0:
+                self.log(f"Warning : scheduler factor is {factor} >= 1.0")
+                self.log("Setting it to 0.5")
+                factor = 0.5
+
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer=optimizer,
+                mode="min",
+                factor=factor,
+                patience=self.config.scheduler_patience,
+                verbose=VERBOSE_SCHEDULER,
+            )
+            dice_metric = DiceMetric(
+                include_background=False, reduction="mean"
+            )

             best_metric = -1
             best_metric_epoch = -1
@@ -1260,9 +1442,9 @@ def train(self):
             if weights_config.custom:
                 if weights_config.use_pretrained:
-                    weights_file = model_class.get_weights_file()
+                    weights_file = model_class.weights_file
                     self.downloader.download_weights(model_name, weights_file)
-                    weights = WEIGHTS_DIR / Path(weights_file)
+                    weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file)
                     weights_config.path = weights
                 else:
                     weights = str(Path(weights_config.path))
@@ -1276,6 +1458,7 @@ def train(self):
                     )
                 except RuntimeError as e:
                     logger.error(f"Error when loading weights : {e}")
+                    logger.exception(e)
                     warn = (
                         "WARNING:\nThe weights seem to be incompatible with the model,\n"
                         "the model will be trained from random weights"
                     )
@@ -1294,9 +1477,9 @@ def train(self):
             device = self.config.device

-            if model_name
== "test": - self.quit() - yield TrainingReport(False) + # if model_name == "test": + # self.quit() + # yield TrainingReport(False) for epoch in range(self.config.max_epochs): # self.log("\n") @@ -1323,8 +1506,14 @@ def train(self): batch_data["label"].to(device), ) optimizer.zero_grad() - outputs = model_class.get_output(model, inputs) + outputs = model(inputs) # self.log(f"Output dimensions : {outputs.shape}") + if outputs.shape[1] > 1: + outputs = outputs[ + :, 1:, :, : + ] # FIXME fix channel number + if len(outputs.shape) < 4: + outputs = outputs.unsqueeze(0) loss = self.config.loss_function(outputs, labels) loss.backward() optimizer.step() @@ -1341,23 +1530,47 @@ def train(self): epoch_loss_values.append(epoch_loss) self.log(f"Epoch: {epoch + 1}, Average loss: {epoch_loss:.4f}") + self.log("Updating scheduler...") + scheduler.step(epoch_loss) + checkpoint_output = [] + eta = ( + (time.time() - start_time) + * (self.config.max_epochs / (epoch + 1) - 1) + / 60 + ) + self.log("ETA: " + f"{eta:.2f}" + " minutes") if ( (epoch + 1) % self.config.validation_interval == 0 or epoch + 1 == self.config.max_epochs ): model.eval() + self.log("Performing validation...") with torch.no_grad(): for val_data in val_loader: val_inputs, val_labels = ( val_data["image"].to(device), val_data["label"].to(device), ) - - val_outputs = model_class.get_validation( - model, val_inputs + try: + with torch.no_grad(): + val_outputs = sliding_window_inference( + val_inputs, + roi_size=size, + sw_batch_size=self.config.batch_size, + predictor=model, + overlap=0.25, + sw_device=self.config.device, + device=self.config.device, + progress=False, + ) + except Exception as e: + self.raise_error(e, "Error during validation") + logger.debug( + f"val_outputs shape : {val_outputs.shape}" ) + # val_outputs = model(val_inputs) pred = decollate_batch(val_outputs) @@ -1404,7 +1617,7 @@ def train(self): weights=model.state_dict(), images=checkpoint_output, ) - + self.log("Validation completed") yield train_report weights_filename = ( @@ -1437,7 +1650,7 @@ def train(self): model.to("cpu") except Exception as e: - self.log(f"Error in training : {e}") + self.raise_error(e, "Error in training") self.quit() finally: self.quit() diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py index 0a613ee7..cfa3f0d7 100644 --- a/napari_cellseg3d/code_plugins/plugin_base.py +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from pathlib import Path @@ -47,15 +46,15 @@ def __init__( self.image_path = None """str: path to image folder""" - self.show_image_io = loads_images + self._show_image_io = loads_images self.label_path = None """str: path to label folder""" - self.show_label_io = loads_labels + self._show_label_io = loads_labels self.results_path = None """str: path to results folder""" - self.show_results_io = has_results + self._show_results_io = has_results self._default_path = [self.image_path, self.label_path] @@ -99,7 +98,7 @@ def __init__( ) self.filetype_choice = ui.DropdownMenu( - [".tif", ".tiff"], label="File format" + [".tif", ".tiff"], text_label="File format" ) ######## qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) @@ -118,7 +117,6 @@ def show_menu(_, event): def _build_io_panel(self): self.io_panel = ui.GroupedWidget("Data") self.save_label = ui.make_label("Save location :", parent=self) - # self.io_panel.setToolTip("IO Panel") ui.add_widgets( @@ -140,25 +138,25 @@ def _build_io_panel(self): return 
self.io_panel def _remove_unused(self): - if not self.show_label_io: + if not self._show_label_io: self.labels_filewidget = None self.label_layer_loader = None - if not self.show_image_io: + if not self._show_image_io: self.image_layer_loader = None self.image_filewidget = None - if not self.show_results_io: + if not self._show_results_io: self.results_filewidget = None def _set_io_visibility(self): ################## # Show when layer is selected - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_layer_loader, self.layer_choice) else: self._hide_io_element(self.image_layer_loader) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.label_layer_loader, self.layer_choice) else: self._hide_io_element(self.label_layer_loader) @@ -168,15 +166,15 @@ def _set_io_visibility(self): f = self.folder_choice self._show_io_element(self.filetype_choice, f) - if self.show_image_io: + if self._show_image_io: self._show_io_element(self.image_filewidget, f) else: self._hide_io_element(self.image_filewidget) - if self.show_label_io: + if self._show_label_io: self._show_io_element(self.labels_filewidget, f) else: self._hide_io_element(self.labels_filewidget) - if not self.show_results_io: + if not self._show_results_io: self._hide_io_element(self.results_filewidget) self.folder_choice.toggle() @@ -229,17 +227,16 @@ def _show_filetype_choice(self): def _show_file_dialog(self): """Open file dialog and process path depending on single file/folder loading behaviour""" if self.load_as_stack_choice.isChecked(): - folder = ui.open_folder_dialog( + choice = ui.open_folder_dialog( self, self._default_path, filetype=f"Image file (*{self.filetype_choice.currentText()})", ) - return folder else: f_name = ui.open_file_dialog(self, self._default_path) - f_name = str(f_name[0]) - self.filetype = str(Path(f_name).suffix) - return f_name + choice = str(f_name[0]) + self.filetype = str(Path(choice).suffix) + return choice def _show_dialog_images(self): """Show file dialog and set image path""" @@ -293,16 +290,14 @@ def _make_close_button(self): return btn def _make_prev_button(self): - btn = ui.Button( + return ui.Button( "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) ) - return btn def _make_next_button(self): - btn = ui.Button( + return ui.Button( "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) ) - return btn def remove_from_viewer(self): """Removes the widget from the napari window. @@ -404,7 +399,7 @@ def load_dataset_paths(self): file_paths = sorted(Path(directory).glob("*" + filetype)) if len(file_paths) == 0: - warnings.warn( + logger.warning( f"The folder does not contain any compatible {filetype} files.\n" f"Please check the validity of the folder and images." 
) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 6c8370c1..4357e51e 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,14 +1,13 @@ -import warnings from pathlib import Path import napari import numpy as np from qtpy.QtWidgets import QSizePolicy -from tifffile import imread, imwrite +from tifffile import imread import napari_cellseg3d.interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, clear_small_objects, threshold, @@ -16,80 +15,12 @@ ) from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder -# TODO break down into multiple mini-widgets -# TODO create parent class for utils modules to avoid duplicates - -MAX_W = 200 -MAX_H = 1000 +MAX_W = ui.UTILS_MAX_WIDTH +MAX_H = ui.UTILS_MAX_HEIGHT logger = utils.LOGGER -def save_folder(results_path, folder_name, images, image_paths): - """ - Saves a list of images in a folder - - Args: - results_path: Path to the folder containing results - folder_name: Name of the folder containing results - images: List of images to save - image_paths: list of filenames of images - """ - results_folder = results_path / Path(folder_name) - results_folder.mkdir(exist_ok=False, parents=True) - - for file, image in zip(image_paths, images): - path = results_folder / Path(file).name - - imwrite( - path, - image, - ) - logger.info(f"Saved processed folder as : {results_folder}") - - -def save_layer(results_path, image_name, image): - """ - Saves an image layer at the specified path - - Args: - results_path: path to folder containing result - image_name: image name for saving - image: data array containing image - - Returns: - - """ - path = str(results_path / Path(image_name)) # TODO flexible filetype - logger.info(f"Saved as : {path}") - imwrite(path, image) - - -def show_result(viewer, layer, image, name): - """ - Adds layers to a viewer to show result to user - - Args: - viewer: viewer to add layer in - layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer - image: the data array containing the image - name: name of the added layer - - Returns: - - """ - if isinstance(layer, napari.layers.Image): - logger.debug("Added resulting image layer") - viewer.add_image(image, name=name) - elif isinstance(layer, napari.layers.Labels): - logger.debug("Added resulting label layer") - viewer.add_labels(image, name=name) - else: - warnings.warn( - f"Results not shown, unsupported layer type {type(layer)}" - ) - - class AnisoUtils(BasePluginFolder): """Class to correct anisotropy in images""" @@ -115,7 +46,7 @@ def __init__(self, viewer: "napari.Viewer.viewer", parent=None): self.aniso_widgets = ui.AnisotropyWidgets(self, always_visible=True) self.start_btn = ui.Button("Start", self._start) - self.results_path = Path.home() / Path("cellseg3d/anisotropy") + self.results_path = str(Path.home() / Path("cellseg3d/anisotropy")) self.results_filewidget.text_field.setText(str(self.results_path)) self.results_filewidget.check_ready() @@ -145,7 +76,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) zoom = self.aniso_widgets.scaling_zyx() if self.layer_choice.isChecked(): @@ -155,12 +86,12 @@ def _start(self): data = np.array(layer.data) 
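                # Editor's note (illustrative only, not part of the patch): `zoom`
                # holds the per-axis z/y/x factors from aniso_widgets.scaling_zyx()
                # above, e.g. ratios reflecting anisotropic z-steps vs. finer xy
                # pixels, which utils.resize uses to interpolate the stack to
                # isotropic voxels before saving and display.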
isotropic_image = utils.resize(data, zoom) - save_layer( + utils.save_layer( self.results_path, f"isotropic_{layer.name}_{utils.get_date_time()}.tif", isotropic_image, ) - show_result( + utils.show_result( self._viewer, layer, isotropic_image, @@ -174,7 +105,7 @@ def _start(self): utils.resize(np.array(imread(file)), zoom) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"isotropic_results_{utils.get_date_time()}", images, @@ -211,7 +142,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): lower=1, upper=100000, default=10, - label="Remove all smaller than (pxs):", + text_label="Remove all smaller than (pxs):", ) self.results_path = Path.home() / Path("cellseg3d/small_removed") @@ -245,7 +176,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.size_for_removal_counter.value() if self.layer_choice: @@ -255,12 +186,12 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, remove_size) - save_layer( + utils.save_layer( self.results_path, f"cleared_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"cleared_{layer.name}" ) elif ( @@ -270,7 +201,7 @@ def _start(self): clear_small_objects(file, remove_size, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"small_removed_results_{utils.get_date_time()}", images, @@ -337,12 +268,12 @@ def _start(self): data = np.array(layer.data) semantic = to_semantic(data) - save_layer( + utils.save_layer( self.results_path, f"semantic_{layer.name}_{utils.get_date_time()}.tif", semantic, ) - show_result( + utils.show_result( self._viewer, layer, semantic, f"semantic_{layer.name}" ) elif ( @@ -352,7 +283,7 @@ def _start(self): to_semantic(file, is_file_path=True) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"semantic_results_{utils.get_date_time()}", images, @@ -414,7 +345,7 @@ def _build(self): ) def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) if self.layer_choice: if self.label_layer_loader.layer_data() is not None: @@ -423,7 +354,7 @@ def _start(self): data = np.array(layer.data) instance = self.instance_widgets.run_method(data) - save_layer( + utils.save_layer( self.results_path, f"instance_{layer.name}_{utils.get_date_time()}.tif", instance, @@ -436,10 +367,10 @@ def _start(self): self.folder_choice.isChecked() and len(self.images_filepaths) != 0 ): images = [ - self.instance_widgets.run_method(imread(file)) + self.instance_widgets.run_method_on_channels(imread(file)) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"instance_results_{utils.get_date_time()}", images, @@ -474,7 +405,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): upper=100000.0, step=0.5, default=10.0, - label="Remove all smaller than (value):", + text_label="Remove all smaller than (value):", ) self.results_path = Path.home() / Path("cellseg3d/threshold") @@ -509,7 +440,7 @@ def _build(self): return container def _start(self): - self.results_path.mkdir(exist_ok=True, parents=True) + utils.mkdir_from_str(self.results_path) remove_size = self.binarize_counter.value() if self.layer_choice: @@ -519,12 +450,12 @@ def _start(self): data = np.array(layer.data) removed = self.function(data, 
remove_size) - save_layer( + utils.save_layer( self.results_path, f"threshold_{layer.name}_{utils.get_date_time()}.tif", removed, ) - show_result( + utils.show_result( self._viewer, layer, removed, f"threshold{layer.name}" ) elif ( @@ -534,7 +465,7 @@ def _start(self): self.function(imread(file), remove_size) for file in self.images_filepaths ] - save_folder( + utils.save_folder( self.results_path, f"threshold_results_{utils.get_date_time()}", images, diff --git a/napari_cellseg3d/code_plugins/plugin_crf.py b/napari_cellseg3d/code_plugins/plugin_crf.py new file mode 100644 index 00000000..76194e87 --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_crf.py @@ -0,0 +1,290 @@ +import contextlib +from functools import partial +from pathlib import Path + +import napari.layers +from qtpy.QtWidgets import QSizePolicy +from tqdm import tqdm + +from napari_cellseg3d import config, utils +from napari_cellseg3d import interface as ui +from napari_cellseg3d.code_models.crf import ( + CRF_INSTALLED, + CRFWorker, + crf_with_config, +) +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage +from napari_cellseg3d.utils import LOGGER as logger + + +# TODO add CRF on folder +class CRFParamsWidget(ui.GroupedWidget): + """Use this widget when adding the crf as part of another widget (rather than a standalone widget)""" + + def __init__(self, parent=None): + super().__init__(title="CRF parameters", parent=parent) + ####### + # CRF params # + self.sa_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Alpha std" + ) + self.sb_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Beta std" + ) + self.sg_choice = ui.DoubleIncrementCounter( + default=1, parent=self, text_label="Gamma std" + ) + self.w1_choice = ui.DoubleIncrementCounter( + default=10, parent=self, text_label="Weight appearance" + ) + self.w2_choice = ui.DoubleIncrementCounter( + default=5, parent=self, text_label="Weight smoothness" + ) + self.n_iter_choice = ui.IntIncrementCounter( + default=5, parent=self, text_label="Number of iterations" + ) + ####### + self._build() + self._set_tooltips() + + def _build(self): + if not CRF_INSTALLED: + ui.add_widgets( + self.layout, + [ + ui.make_label( + "ERROR: CRF not installed.\nPlease refer to the documentation to install it." + ), + ], + ) + self.set_layout() + return + ui.add_widgets( + self.layout, + [ + # self.sa_choice.label, + self.sa_choice, + # self.sb_choice.label, + self.sb_choice, + # self.sg_choice.label, + self.sg_choice, + # self.w1_choice.label, + self.w1_choice, + # self.w2_choice.label, + self.w2_choice, + # self.n_iter_choice.label, + self.n_iter_choice, + ], + ) + self.set_layout() + + def _set_tooltips(self): + self.sa_choice.setToolTip( + "SA : Standard deviation of the Gaussian kernel in the appearance term." + ) + self.sb_choice.setToolTip( + "SB : Standard deviation of the Gaussian kernel in the smoothness term." + ) + self.sg_choice.setToolTip( + "SG : Standard deviation of the Gaussian kernel in the gradient term." + ) + self.w1_choice.setToolTip( + "W1 : Weight of the appearance term in the CRF." + ) + self.w2_choice.setToolTip( + "W2 : Weight of the smoothness term in the CRF." 
+ ) + self.n_iter_choice.setToolTip("Number of iterations of the CRF.") + + def make_config(self): + return config.CRFConfig( + sa=self.sa_choice.value(), + sb=self.sb_choice.value(), + sg=self.sg_choice.value(), + w1=self.w1_choice.value(), + w2=self.w2_choice.value(), + n_iters=self.n_iter_choice.value(), + ) + + +class CRFWidget(BasePluginSingleImage): + def __init__(self, viewer, parent=None): + """ + Create a widget for CRF post-processing. + Args: + viewer: napari viewer to display the widget + parent: parent widget. Defaults to None. + """ + super().__init__(viewer, parent) + self._viewer = viewer + + self.start_button = ui.Button("Start", self._start, parent=self) + self.crf_params_widget = CRFParamsWidget(parent=self) + self.io_panel = self._build_io_panel() + self.io_panel.setVisible(False) + + self.results_filewidget.setVisible(True) + self.label_layer_loader.setVisible(True) + self.label_layer_loader.set_layer_type( + napari.layers.Image + ) # to load all crf-compatible inputs, not int only + self.image_layer_loader.setVisible(True) + if CRF_INSTALLED: + self.start_button.setVisible(True) + else: + self.start_button.setVisible(False) + + self.result_layer = None + self.result_name = None + self.crf_results = [] + + self.results_path = Path.home() / Path("cellseg3d/crf") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._container = ui.ContainerWidget(parent=self, l=11, t=11, r=11) + self.layout = self._container.layout + + self._build() + + self.worker = None + self.log = None + + def _build(self): + self.setMinimumWidth(100) + ui.add_widgets( + self.layout, + [ + self.image_layer_loader, + self.label_layer_loader, + self.save_label, + self.results_filewidget, + ui.make_label(""), + self.crf_params_widget, + ui.make_label(""), + self.start_button, + ], + ) + # self.io_panel.setLayout(self.io_panel.layout) + self.setLayout(self.layout) + + ui.ScrollArea.make_scrollable( + self.layout, self, max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT] + ) + self._container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return self._container + + def make_config(self): + return self.crf_params_widget.make_config() + + def print_config(self): + logger.info("CRF config:") + for item in self.make_config().__dict__.items(): + logger.info(f"{item[0]}: {item[1]}") + + def _check_ready(self): + if len(self.label_layer_loader.layer_list) < 1: + logger.warning("No label layer loaded") + return False + if len(self.image_layer_loader.layer_list) < 1: + logger.warning("No image layer loaded") + return False + + if len(self.label_layer_loader.layer_data().shape) < 3: + logger.warning("Label layer must be 3D") + return False + if len(self.image_layer_loader.layer_data().shape) < 3: + logger.warning("Image layer must be 3D") + return False + if ( + self.label_layer_loader.layer_data().shape[-3:] + != self.image_layer_loader.layer_data().shape[-3:] + ): + logger.warning("Image and label layers must have the same shape!") + return False + + return True + + def run_crf_on_batch(self, images_list: list, labels_list: list, log=None): + self.crf_results = [] + for image, label in zip(images_list, labels_list): + tqdm( + unit="B", + total=len(images_list), + position=0, + file=log, + ) + result = crf_with_config(image, label, self.make_config()) + self.crf_results.append(result) + return self.crf_results + + def _prepare_worker(self, images_list: list, labels_list: list): + self.worker = CRFWorker( + 
images_list=images_list, + labels_list=labels_list, + config=self.make_config(), + ) + + self.worker.started.connect(self._on_start) + self.worker.yielded.connect(partial(self._on_yield)) + self.worker.errored.connect(partial(self._on_error)) + self.worker.finished.connect(self._on_finish) + + def _start(self): + if not self._check_ready(): + return + + self.result_layer = self.label_layer_loader.layer() + self.result_name = self.label_layer_loader.layer_name() + + utils.mkdir_from_str(self.results_path) + + image_list = [self.image_layer_loader.layer_data()] + labels_list = [self.label_layer_loader.layer_data()] + [logger.debug(f"Image shape: {image.shape}") for image in image_list] + [ + logger.debug(f"Label shape: {labels.shape}") + for labels in labels_list + ] + + self._prepare_worker(image_list, labels_list) + + if self.worker.is_running: # if worker is running, tries to stop + logger.info("Stop request, waiting for previous job to finish") + self.start_button.setText("Stopping...") + self.worker.quit() + else: # once worker is started, update buttons + self.start_button.setText("Running...") + logger.info("Starting CRF...") + self.worker.start() + + def _on_yield(self, result): + self.crf_results.append(result) + + utils.save_layer( + self.results_filewidget.text_field.text(), + str(self.result_name + "_crf.tif"), + result, + ) + self._viewer.add_image( + result, + name="crf_" + self.result_name, + ) + + def _on_start(self): + self.crf_results = [] + + def _on_finish(self): + self.worker = None + with contextlib.suppress(RuntimeError): + self.start_button.setText("Start") + + # should only happen when testing + + def _on_error(self, error): + logger.error(error) + self.start_button.setText("Start") + self.worker.quit() + self.worker = None diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py index 9830d51e..74691e1f 100644 --- a/napari_cellseg3d/code_plugins/plugin_crop.py +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -1,4 +1,4 @@ -import warnings +from math import floor from pathlib import Path import napari @@ -44,13 +44,16 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.image_layer_loader.set_layer_type(napari.layers.Layer) self.image_layer_loader.layer_list.label.setText("Image 1") + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self.auto_set_dims + ) # ui.LayerSelecter(self._viewer, "Image 1") # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") self.label_layer_loader.set_layer_type(napari.layers.Layer) self.label_layer_loader.layer_list.label.setText("Image 2") self.crop_second_image_choice = ui.CheckBox( - "Crop another\nimage simultaneously", + "Crop another\nimage/label simultaneously", ) self.crop_second_image_choice.toggled.connect( self._toggle_second_image_io_visibility @@ -81,7 +84,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.results_filewidget.check_ready() self.crop_size_widgets = ui.IntIncrementCounter.make_n( - 3, 1, 1000, DEFAULT_CROP_SIZE + 3, 1, 10000, DEFAULT_CROP_SIZE ) self.crop_size_labels = [ ui.make_label("Size in " + axis + " of cropped volume :", self) @@ -113,6 +116,8 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._build() self._toggle_second_image_io_visibility() + self._check_image_list() + self.auto_set_dims() def _toggle_second_image_io_visibility(self): crop_2nd = self.crop_second_image_choice.isChecked() @@ -133,6 +138,18 @@ def _check_image_list(self): except IndexError: 
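
# A standalone sketch of what the auto_set_dims slot added just below does:
# suggest half of each axis as the default crop size for a 3D volume. The
# helper name default_crop_sizes is made up here for illustration.
from math import floor

def default_crop_sizes(shape):
    """Return half of each dimension for a 3-axis shape, else None."""
    if len(shape) == 3:
        return [floor(dim / 2) for dim in shape]
    return None

assert default_crop_sizes((64, 128, 100)) == [32, 64, 50]
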
return + def auto_set_dims(self): + logger.debug(self.image_layer_loader.layer_name()) + data = self.image_layer_loader.layer_data() + if data is not None: + logger.debug(f"auto_set_dims : {data.shape}") + if len(data.shape) == 3: + for i, box in enumerate(self.crop_size_widgets): + logger.debug( + f"setting dim {i} to {floor(data.shape[i]/2)}" + ) + box.setValue(floor(data.shape[i] / 2)) + def _build(self): """Build buttons in a layout and add them to the napari Viewer""" @@ -158,8 +175,10 @@ def _build(self): dim_group_l.addWidget(self.aniso_widgets) [ dim_group_l.addWidget(widget, alignment=ui.ABS_AL) - for list in zip(self.crop_size_labels, self.crop_size_widgets) - for widget in list + for widget_list in zip( + self.crop_size_labels, self.crop_size_widgets + ) + for widget in widget_list ] dim_group_w.setLayout(dim_group_l) layout.addWidget(dim_group_w) @@ -175,7 +194,12 @@ def _build(self): ], ) - ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 200]) + ui.ScrollArea.make_scrollable( + layout, + self, + max_wh=[ui.UTILS_MAX_WIDTH, ui.UTILS_MAX_HEIGHT], + min_wh=[200, 200], + ) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.MinimumExpanding) self._set_io_visibility() @@ -245,7 +269,7 @@ def _start(self): # maybe use singletons or make docked widgets attributes that are hidden upon opening if not self._check_ready(): - warnings.warn("Please select at least one valid layer !") + logger.warning("Please select at least one valid layer !") return # self._viewer.window.remove_dock_widget(self.parent()) # no need to close utils ? @@ -260,9 +284,9 @@ def _start(self): except ValueError as e: logger.warning(e) logger.warning( - "Could not remove cropping layer programmatically!" + "Could not remove the previous cropping layer programmatically." ) - logger.warning("Maybe layer has been removed by user?") + # logger.warning("Maybe layer has been removed by user?") self.results_path = Path(self.results_filewidget.text_field.text()) @@ -329,7 +353,7 @@ def add_isotropic_layer( self, layer, colormap="inferno", - contrast_lim=[200, 1000], # TODO generalize ? + contrast_lim=(200, 1000), # TODO generalize ? 
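
# What this helper boils down to: napari can render an anisotropic volume at
# its true proportions by passing per-axis scale factors to add_image. A
# hedged sketch; the 5-to-1 z spacing is an invented example value.
import napari
import numpy as np

viewer = napari.Viewer()
volume = np.random.rand(32, 256, 256)  # z, y, x
viewer.add_image(
    volume,
    colormap="inferno",
    scale=[5.0, 1.0, 1.0],  # assumed voxel size ratio along z, y, x
    opacity=0.7,
)
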
opacity=0.7, visible=True, ): @@ -340,7 +364,7 @@ def add_isotropic_layer( layer.data, name=f"Scaled_{layer.name}", colormap=colormap, - contrast_limits=contrast_lim, + # contrast_limits=contrast_lim, opacity=opacity, scale=self.aniso_factors, visible=visible, @@ -413,9 +437,8 @@ def _add_crop_sliders( box.value() for box in self.crop_size_widgets ] ############# - dims = [self._x, self._y, self._z] - [logger.debug(f"{dim}") for dim in dims] - logger.debug("SET DIMS ATTEMPT") + # [logger.debug(f"{dim}") for dim in dims] + # logger.debug("SET DIMS ATTEMPT") # if not self.create_new_layer.isChecked(): # self._x = x # self._y = y @@ -431,6 +454,8 @@ def _add_crop_sliders( # define crop sizes and boundaries for the image crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] + # [logger.debug(f"{crop}") for crop in crop_sizes] + # logger.debug("SET CROP ATTEMPT") for i in range(len(crop_sizes)): if crop_sizes[i] > im1_stack.shape[i]: @@ -475,8 +500,8 @@ def set_slice( """ "Update cropped volume position""" # self._check_for_empty_layer(highres_crop_layer, highres_crop_layer.data) - logger.debug(f"axis : {axis}") - logger.debug(f"value : {value}") + # logger.debug(f"axis : {axis}") + # logger.debug(f"value : {value}") idx = int(value) scale = np.asarray(highres_crop_layer.scale) @@ -490,6 +515,20 @@ def set_slice( cropy = self._crop_size_y cropz = self._crop_size_z + if i + cropx > im1_stack.shape[0]: + cropx = im1_stack.shape[0] - i + if j + cropy > im1_stack.shape[1]: + cropy = im1_stack.shape[1] - j + if k + cropz > im1_stack.shape[2]: + cropz = im1_stack.shape[2] - k + + logger.debug(f"cropx : {cropx}") + logger.debug(f"cropy : {cropy}") + logger.debug(f"cropz : {cropz}") + logger.debug(f"i : {i}") + logger.debug(f"j : {j}") + logger.debug(f"k : {k}") + highres_crop_layer.data = im1_stack[ i : i + cropx, j : j + cropy, k : k + cropz ] @@ -533,7 +572,7 @@ def set_slice( # container_widget.extend(sliders) ui.add_widgets( container_widget.layout, - [ui.combine_blocks(s, s.text_label) for s in sliders], + [ui.combine_blocks(s, s.label) for s in sliders], ) # vw.window.add_dock_widget([spinbox, container_widget], area="right") wdgts = vw.window.add_dock_widget( diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index f8ac18ef..552f70ea 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,6 +1,8 @@ import pathlib +from typing import TYPE_CHECKING -import napari +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import QSize @@ -37,7 +39,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.2rc6'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.3rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/napari_cellseg3d/code_plugins/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py index 114025f6..1dc5e7de 100644 --- a/napari_cellseg3d/code_plugins/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt -import napari import numpy as np from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, @@ -8,9 +9,12 @@ from monai.transforms import SpatialPad, 
ToTensor from tifffile import imread +if TYPE_CHECKING: + import napari + from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_models.instance_segmentation import to_semantic from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder DEFAULT_THRESHOLD = 0.5 @@ -19,7 +23,7 @@ class MetricsUtils(BasePluginFolder): """Plugin to evaluate metrics between two sets of labels, ground truth and prediction""" - def __init__(self, viewer: "napari.viewer.Viewer", parent): + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): """Creates a MetricsUtils widget for computing and plotting dice metrics between labels. Args: viewer: viewer to display the widget in @@ -187,11 +191,11 @@ def compute_dice(self): self.canvas = ( None # kind of terrible way to stack plots... but it works. ) - id = 0 + image_id = 0 for ground_path, pred_path in zip( self.images_filepaths, self.labels_filepaths ): - id += 1 + image_id += 1 ground = imread(ground_path) pred = imread(pred_path) diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 22867343..256cffa4 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,22 +1,27 @@ -import warnings from functools import partial +from typing import TYPE_CHECKING -import napari import numpy as np import pandas as pd +if TYPE_CHECKING: + import napari + # local from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui -from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_instance_seg import ( +from napari_cellseg3d.code_models.instance_segmentation import ( InstanceMethod, InstanceWidgets, ) -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.workers import ( InferenceResult, InferenceWorker, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFParamsWidget + +logger = utils.LOGGER class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): @@ -109,11 +114,14 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ###################### # TODO : better way to handle SegResNet size reqs ? 
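
# The window-inference widgets below ultimately fill a SlidingWindowConfig
# (see _set_worker_config further down). A hedged sketch with hard-coded
# values standing in for the widget reads; 0.25 is an assumed overlap
# fraction, not a value taken from this diff.
from napari_cellseg3d import config

window_config = config.SlidingWindowConfig(
    window_size=64,  # the dropdown's default; WNet is pinned to 64 below
    window_overlap=0.25,
)
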
self.model_input_size = ui.IntIncrementCounter( - lower=1, upper=1024, default=128, label="\nModel input size" + lower=1, upper=1024, default=128, text_label="\nModel input size" ) self.model_choice.currentIndexChanged.connect( self._toggle_display_model_input_size ) + self.model_choice.currentIndexChanged.connect( + self._restrict_window_size_for_model + ) self.model_choice.setCurrentIndex(0) self.anisotropy_wdgt = ui.AnisotropyWidgets( @@ -138,7 +146,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.thresholding_slider = ui.Slider( - lower=1, default=config.PostProcessConfig().thresholding.threshold_value * 100, divide_factor=100.0, @@ -146,9 +153,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.window_infer_box = ui.CheckBox("Use window inference") - self.window_infer_box.clicked.connect(self._toggle_display_window_size) + self.window_infer_box.toggled.connect(self._toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + self._default_window_size = sizes_window.index("64") # ( # self.window_size_choice, # self.window_size_choice.label, @@ -161,8 +169,11 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): # ) self.window_size_choice = ui.DropdownMenu( - sizes_window, label="Window size" + sizes_window, text_label="Window size" ) + self.window_size_choice.setCurrentIndex( + self._default_window_size + ) # set to 64 by default self.window_overlap_slider = ui.Slider( default=config.SlidingWindowConfig.window_overlap * 100, @@ -187,15 +198,22 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) - self.window_size_choice.setCurrentIndex(3) # default size to 64 ################## ################## # instance segmentation widgets self.instance_widgets = InstanceWidgets(parent=self) + self.crf_widgets = CRFParamsWidget(parent=self) self.use_instance_choice = ui.CheckBox( - "Run instance segmentation", func=self._toggle_display_instance + "Run instance segmentation", + func=self._toggle_display_instance, + parent=self, + ) + self.use_crf = ui.CheckBox( + "Use CRF post-processing", + func=self._toggle_display_crf, + parent=self, ) self.save_stats_to_csv_box = ui.CheckBox( @@ -286,6 +304,19 @@ def check_ready(self): return True return False + def _restrict_window_size_for_model(self): + """Sets the window size to a value that is compatible with the chosen model""" + if self.model_choice.currentText() == "WNet": + self.window_size_choice.setCurrentIndex(self._default_window_size) + self.window_size_choice.setDisabled(True) + self.window_infer_box.setChecked(True) + self.window_infer_box.setDisabled(True) + else: + self.window_size_choice.setDisabled(False) + self.window_infer_box.setDisabled(False) + self.window_infer_box.setChecked(False) + self.window_size_choice.setCurrentIndex(self._default_window_size) + def _toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" @@ -307,6 +338,10 @@ def _toggle_display_thresh(self): self.thresholding_checkbox, self.thresholding_slider.container ) + def _toggle_display_crf(self): + """Shows the choices for CRF post-processing depending on whether :py:attr:`self.use_crf` is checked""" + ui.toggle_visibility(self.use_crf, self.crf_widgets) + def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection""" ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) @@ -315,6 +350,18 @@ def 
_toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box""" ui.toggle_visibility(self.window_infer_box, self.window_infer_params) + def _load_weights_path(self): + """Show file dialog to set :py:attr:`model_path`""" + + # logger.debug(self._default_weights_folder) + + file = ui.open_file_dialog( + self, + [self._default_weights_folder], + file_extension="Weights file (*.pth *.pt *.onnx)", + ) + self._update_weights_path(file) + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer""" @@ -422,6 +469,8 @@ def _build(self): self.anisotropy_wdgt, # anisotropy self.thresholding_checkbox, self.thresholding_slider.container, # thresholding + self.use_crf, + self.crf_widgets, self.use_instance_choice, self.instance_widgets, self.save_stats_to_csv_box, @@ -435,6 +484,7 @@ def _build(self): self.anisotropy_wdgt.container.setVisible(False) self.thresholding_slider.container.setVisible(False) self.instance_widgets.setVisible(False) + self.crf_widgets.setVisible(False) self.save_stats_to_csv_box.setVisible(False) post_proc_group.setLayout(post_proc_layout) @@ -531,64 +581,7 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.model_info = config.ModelInfo( - name=self.model_choice.currentText(), - model_input_size=self.model_input_size.value(), - ) - - self.weights_config.custom = self.custom_weights_choice.isChecked() - - save_path = self.results_filewidget.text_field.text() - if not self._check_results_path(save_path): - msg = f"ERROR: please set valid results path. Current path is {save_path}" - self.log.print_and_log(msg) - warnings.warn(msg) - else: - if self.results_path is None: - self.results_path = save_path - - zoom_config = config.Zoom( - enabled=self.anisotropy_wdgt.enabled(), - zoom_values=self.anisotropy_wdgt.scaling_xyz(), - ) - thresholding_config = config.Thresholding( - enabled=self.thresholding_checkbox.isChecked(), - threshold_value=self.thresholding_slider.slider_value, - ) - - self.instance_config = config.InstanceSegConfig( - enabled=self.use_instance_choice.isChecked(), - method=self.instance_widgets.methods[ - self.instance_widgets.method_choice.currentText() - ], - ) - - self.post_process_config = config.PostProcessConfig( - zoom=zoom_config, - thresholding=thresholding_config, - instance=self.instance_config, - ) - - if self.window_infer_box.isChecked(): - size = int(self.window_size_choice.currentText()) - window_config = config.SlidingWindowConfig( - window_size=size, - window_overlap=self.window_overlap_slider.slider_value, - ) - else: - window_config = config.SlidingWindowConfig() - - self.worker_config = config.InferenceWorkerConfig( - device=self.get_device(), - model_info=self.model_info, - weights_config=self.weights_config, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), - compute_stats=self.save_stats_to_csv_box.isChecked(), - post_process_config=self.post_process_config, - sliding_window_config=window_config, - ) + self._set_worker_config() ##################### ##################### ##################### @@ -607,10 +600,13 @@ def start(self): self.worker.set_download_log(self.log) self.worker.started.connect(self.on_start) + self.worker.log_signal.connect(self.log.print_and_log) self.worker.warn_signal.connect(self.log.warn) + self.worker.error_signal.connect(self.log.error) + self.worker.yielded.connect(partial(self.on_yield)) # - 
self.worker.errored.connect(partial(self.on_yield)) + self.worker.errored.connect(partial(self.on_error)) self.worker.finished.connect(self.on_finish) if self.get_device(show=False) == "cuda": @@ -627,6 +623,76 @@ def start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _create_worker_from_config( + self, worker_config: config.InferenceWorkerConfig + ): + if isinstance(worker_config, config.InfererConfig): + raise TypeError("Please provide a valid worker config object") + return InferenceWorker(worker_config=worker_config) + + def _set_worker_config(self) -> config.InferenceWorkerConfig: + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) + + self.weights_config.custom = self.custom_weights_choice.isChecked() + + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + logger.warning(msg) + else: + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.methods[ + self.instance_widgets.method_choice.currentText() + ], + ) + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, + ) + else: + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, + use_crf=self.use_crf.isChecked(), + crf_config=self.crf_widgets.make_config(), + ) + return self.worker_config + def on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`""" self.display_status_report() @@ -647,15 +713,18 @@ def on_start(self): self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") - def on_error(self): - """Catches errors and tries to clean up. TODO : upgrade""" + def on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" 
* 20) self.log.print_and_log("Worker errored...") - self.log.print_and_log("Trying to clean up...") + self.log.error(error) + # self.log.print_and_log("Trying to clean up...") + self.worker.quit() self.btn_start.setText("Start") self.btn_close.setVisible(True) - self.worker = None self.worker_config = None + self.worker = None self.empty_cuda_cache() def on_finish(self): @@ -678,85 +747,112 @@ def on_yield(self, result: InferenceResult): data (dict): dict yielded by :py:func:`~inference()`, contains : "image_id" : index of the returned image, "original" : original volume used for inference, "result" : inference result widget (QWidget): widget for accessing attributes """ + + if isinstance(result, Exception): + self.on_error(result) + # raise result # viewer, progress, show_res, show_res_number, zoon, show_original # check that viewer checkbox is on and that max number of displays has not been reached. # widget.log.print_and_log(result) + try: + image_id = result.image_id + model_name = result.model_name + if self.worker_config.images_filepaths is not None: + total = len(self.worker_config.images_filepaths) + else: + total = 1 - image_id = result.image_id - model_name = result.model_name - if self.worker_config.images_filepaths is not None: - total = len(self.worker_config.images_filepaths) - else: - total = 1 + viewer = self._viewer - viewer = self._viewer + pbar_value = image_id // total + if pbar_value == 0: + pbar_value = 1 - pbar_value = image_id // total - if pbar_value == 0: - pbar_value = 1 + self.progress.setValue(100 * pbar_value) - self.progress.setValue(100 * pbar_value) + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): + zoom = self.worker_config.post_process_config.zoom.zoom_values - if ( - self.config.show_results - and image_id <= self.config.show_results_count - ): - zoom = self.worker_config.post_process_config.zoom.zoom_values + viewer.dims.ndisplay = 3 + viewer.scale_bar.visible = True + + if self.config.show_original and result.original is not None: + viewer.add_image( + result.original, + colormap="inferno", + name=f"original_{image_id}", + scale=zoom, + opacity=0.7, + ) - viewer.dims.ndisplay = 3 - viewer.scale_bar.visible = True + out_colormap = "twilight" + if self.worker_config.post_process_config.thresholding.enabled: + out_colormap = "turbo" - if self.config.show_original and result.original is not None: viewer.add_image( - result.original, - colormap="inferno", - name=f"original_{image_id}", - scale=zoom, - opacity=0.7, + result.result, + colormap=out_colormap, + name=f"pred_{image_id}_{model_name}", + opacity=0.8, ) + if result.crf_results is not None: + logger.debug( + f"CRF results shape : {result.crf_results.shape}" + ) + viewer.add_image( + result.crf_results, + name=f"CRF_results_image_{image_id}", + colormap="viridis", + ) + if ( + result.instance_labels is not None + and self.worker_config.post_process_config.instance.enabled + ): + method_name = ( + self.worker_config.post_process_config.instance.method.name + ) - out_colormap = "twilight" - if self.worker_config.post_process_config.thresholding.enabled: - out_colormap = "turbo" - - viewer.add_image( - result.result, - colormap=out_colormap, - name=f"pred_{image_id}_{model_name}", - opacity=0.8, - ) + number_cells = ( + np.unique(result.instance_labels.flatten()).size - 1 + ) # remove background - if result.instance_labels is not None: - labels = result.instance_labels - method_name = ( - self.worker_config.post_process_config.instance.method.name - ) + name = 
f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" - number_cells = ( - np.unique(labels.flatten()).size - 1 - ) # remove background + viewer.add_labels(result.instance_labels, name=name) - name = f"({number_cells} objects)_{method_name}_instance_labels_{image_id}" + from napari_cellseg3d.utils import LOGGER as log - viewer.add_labels(labels, name=name) + if result.stats is not None and isinstance( + result.stats, list + ): + log.debug(f"len stats : {len(result.stats)}") - stats = result.stats + for i, stats in enumerate(result.stats): + # stats = result.stats - if self.worker_config.compute_stats and stats is not None: - stats_dict = stats.get_dict() - stats_df = pd.DataFrame(stats_dict) + if ( + self.worker_config.compute_stats + and stats is not None + ): + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) - self.log.print_and_log( - f"Number of instances : {stats.number_objects}" - ) + self.log.print_and_log( + f"Number of instances in channel {i} : {stats.number_objects[0]}" + ) - csv_name = f"/{method_name}_seg_results_{image_id}_{utils.get_date_time()}.csv" - stats_df.to_csv( - self.worker_config.results_path + csv_name, - index=False, - ) + csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv" + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, + ) - # self.log.print_and_log( - # f"OBJECTS DETECTED : {number_cells}\n" - # ) + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" + # ) + except Exception as e: + self.on_error(e) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index cf8e4b85..3e666dcc 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,10 +1,9 @@ import shutil -import warnings from functools import partial from pathlib import Path +from typing import TYPE_CHECKING import matplotlib.pyplot as plt -import napari import numpy as np import pandas as pd import torch @@ -13,6 +12,9 @@ ) from matplotlib.figure import Figure +if TYPE_CHECKING: + import napari + # MONAI from monai.losses import ( DiceCELoss, @@ -30,7 +32,7 @@ from napari_cellseg3d import config, utils from napari_cellseg3d import interface as ui from napari_cellseg3d.code_models.model_framework import ModelFramework -from napari_cellseg3d.code_models.model_workers import ( +from napari_cellseg3d.code_models.workers import ( TrainingReport, TrainingWorker, ) @@ -46,6 +48,8 @@ class Trainer(ModelFramework, metaclass=ui.QWidgetSingleton): Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" + default_config = config.TrainingWorkerConfig() + def __init__( self, viewer: "napari.viewer.Viewer", @@ -165,17 +169,18 @@ def __init__( self.validation_values = [] # self.model_choice.setCurrentIndex(0) + wnet_index = self.model_choice.findText("WNet") + self.model_choice.removeItem(wnet_index) ################################ # interface - default = config.TrainingWorkerConfig() self.zip_choice = ui.CheckBox("Compress results") self.validation_percent_choice = ui.Slider( lower=10, upper=90, - default=default.validation_percent * 100, + default=self.default_config.validation_percent * 100, step=5, parent=self, ) @@ -183,12 +188,12 @@ def __init__( self.epoch_choice = ui.IntIncrementCounter( lower=2, upper=200, - default=default.max_epochs, - label="Number of 
epochs : ",
+            default=self.default_config.max_epochs,
+            text_label="Number of epochs : ",
        )

        self.loss_choice = ui.DropdownMenu(
-            sorted(self.loss_dict.keys()), label="Loss function"
+            sorted(self.loss_dict.keys()), text_label="Loss function"
        )
        self.lbl_loss_choice = self.loss_choice.label
        self.loss_choice.setCurrentIndex(0)
@@ -196,7 +201,7 @@
        self.sample_choice_slider = ui.Slider(
            lower=2,
            upper=50,
-            default=default.num_samples,
+            default=self.default_config.num_samples,
            text_label="Number of patches per image : ",
        )
@@ -205,13 +210,13 @@
        self.batch_choice = ui.Slider(
            lower=1,
            upper=10,
-            default=default.batch_size,
+            default=self.default_config.batch_size,
            text_label="Batch size : ",
        )

        self.val_interval_choice = ui.IntIncrementCounter(
-            default=default.validation_interval,
-            label="Validation interval : ",
+            default=self.default_config.validation_interval,
+            text_label="Validation interval : ",
        )

        self.epoch_choice.valueChanged.connect(self._update_validation_choice)
@@ -228,12 +233,24 @@
        ]

        self.learning_rate_choice = ui.DropdownMenu(
-            learning_rate_vals, label="Learning rate"
+            learning_rate_vals, text_label="Learning rate"
        )
        self.lbl_learning_rate_choice = self.learning_rate_choice.label
        self.learning_rate_choice.setCurrentIndex(1)

+        self.scheduler_patience_choice = ui.IntIncrementCounter(
+            1,
+            99,
+            default=self.default_config.scheduler_patience,
+            text_label="Scheduler patience",
+        )
+        self.scheduler_factor_choice = ui.Slider(
+            divide_factor=100,
+            default=self.default_config.scheduler_factor * 100,
+            text_label="Scheduler factor :",
+        )
+
        self.augment_choice = ui.CheckBox("Augment data")

        self.close_buttons = [
@@ -268,7 +285,8 @@
            "Deterministic training", func=self._toggle_deterministic_param
        )
        self.box_seed = ui.IntIncrementCounter(
-            upper=10000000, default=default.deterministic_config.seed
+            upper=10000000,
+            default=self.default_config.deterministic_config.seed,
        )
        self.lbl_seed = ui.make_label("Seed", self)
        self.container_seed = ui.combine_blocks(
@@ -309,6 +327,12 @@ def set_tooltips():
        self.learning_rate_choice.setToolTip(
            "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights"
        )
+        self.scheduler_factor_choice.setToolTip(
+            "The factor by which to reduce the learning rate once the loss reaches a plateau"
+        )
+        self.scheduler_patience_choice.setToolTip(
+            "The number of epochs to wait before reducing the learning rate"
+        )
        self.augment_choice.setToolTip(
            "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images"
            " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image"
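
# The two scheduler widgets above end up in TrainingWorkerConfig
# (scheduler_patience, scheduler_factor; see _set_worker_config below). A
# hedged sketch of the wiring those names suggest on the worker side,
# assuming torch's ReduceLROnPlateau; the model, optimizer and loss are
# placeholders, not values taken from this diff.
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,   # scheduler_factor default
    patience=10,  # scheduler_patience default
)
for epoch in range(3):
    validation_loss = torch.tensor(1.0)  # placeholder metric
    scheduler.step(validation_loss)
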
@@ -395,12 +419,10 @@ def check_ready(self):
        * False and displays a warning if not
        """
-        if self.images_filepaths != [] and self.labels_filepaths != []:
-            return True
-        else:
-            warnings.formatwarning = utils.format_Warning
-            warnings.warn("Image and label paths are not correctly set")
+        if self.images_filepaths == [] or self.labels_filepaths == []:
+            logger.warning("Image and label paths are not correctly set")
            return False
+        return True

    def _build(self):
        """Builds the layout of the widget and creates the following tabs and prompts:
@@ -632,26 +654,20 @@ def _build(self):
            "Training parameters", r=1, b=5, t=11
        )

-        spacing = 20
-
        ui.add_widgets(
            train_param_group_l,
            [
                self.batch_choice.container,  # batch size
-                ui.combine_blocks(
-                    self.learning_rate_choice,
-                    self.lbl_learning_rate_choice,
-                    min_spacing=spacing,
-                    horizontal=False,
-                    l=5,
-                    t=5,
-                    r=5,
-                    b=5,
-                ),  # learning rate
+                self.lbl_learning_rate_choice,
+                self.learning_rate_choice,
                self.epoch_choice.label,  # epochs
                self.epoch_choice,
                self.val_interval_choice.label,
                self.val_interval_choice,  # validation interval
+                self.scheduler_patience_choice.label,
+                self.scheduler_patience_choice,
+                self.scheduler_factor_choice.label,
+                self.scheduler_factor_choice.container,
            ],
            None,
        )
@@ -773,7 +789,7 @@ def start(self):
        if not self.check_ready():  # issues a warning if not ready
            err = "Aborting, please set all required paths"
            self.log.print_and_log(err)
-            warnings.warn(err)
+            logger.warning(err)
            return

        if self.worker is not None:
@@ -794,64 +810,12 @@ def start(self):
                self.data = None
                raise err

-            model_config = config.ModelInfo(
-                name=self.model_choice.currentText()
-            )
-
-            self.weights_config.path = self.weights_config.path
-            self.weights_config.custom = self.custom_weights_choice.isChecked()
-            self.weights_config.use_pretrained = (
-                not self.use_transfer_choice.isChecked()
-            )
-
-            deterministic_config = config.DeterministicConfig(
-                enabled=self.use_deterministic_choice.isChecked(),
-                seed=self.box_seed.value(),
-            )
-
-            validation_percent = (
-                self.validation_percent_choice.slider_value / 100
-            )
-
-            results_path_folder = Path(
-                self.results_path
-                + f"/{model_config.name}_{utils.get_date_time()}"
-            )
-            Path(results_path_folder).mkdir(
-                parents=True, exist_ok=False
-            )  # avoid overwrite where possible
-
-            patch_size = [w.value() for w in self.patch_size_widgets]
-
-            logger.debug("Loading config...")
-            self.worker_config = config.TrainingWorkerConfig(
-                device=self.get_device(),
-                model_info=model_config,
-                weights_info=self.weights_config,
-                train_data_dict=self.data,
-                validation_percent=validation_percent,
-                max_epochs=self.epoch_choice.value(),
-                loss_function=self.get_loss(self.loss_choice.currentText()),
-                learning_rate=float(self.learning_rate_choice.currentText()),
-                validation_interval=self.val_interval_choice.value(),
-                batch_size=self.batch_choice.slider_value,
-                results_path_folder=str(results_path_folder),
-                sampling=self.patch_choice.isChecked(),
-                num_samples=self.sample_choice_slider.slider_value,
-                sample_size=patch_size,
-                do_augmentation=self.augment_choice.isChecked(),
-                deterministic_config=deterministic_config,
-            )  # TODO(cyril) continue to put params in config
-
            self.config = config.TrainerConfig(
                save_as_zip=self.zip_choice.isChecked()
            )
+            self._set_worker_config()

-            self.log.print_and_log(
-                f"Saving results to : {results_path_folder}"
-            )
-
-            self.worker = TrainingWorker(config=self.worker_config)
+            self.worker = TrainingWorker(worker_config=self.worker_config)
            self.worker.set_download_log(self.log)

            [btn.setVisible(False) for btn in self.close_buttons]
@@ -879,6 +843,64 @@ def start(self):
        self.worker.start()
        self.btn_start.setText("Running... Click to stop")

+    def _create_worker_from_config(
+        self, worker_config: config.TrainingWorkerConfig
+    ):
+        if isinstance(worker_config, config.TrainerConfig):
+            raise TypeError(
+                "Expected a TrainingWorkerConfig, got a TrainerConfig"
+            )
+        return TrainingWorker(worker_config=worker_config)
+
+    def _set_worker_config(self) -> config.TrainingWorkerConfig:
+        model_config = config.ModelInfo(name=self.model_choice.currentText())
+
+        self.weights_config.path = self.weights_config.path
+        self.weights_config.custom = self.custom_weights_choice.isChecked()
+        self.weights_config.use_pretrained = (
+            not self.use_transfer_choice.isChecked()
+        )
+
+        deterministic_config = config.DeterministicConfig(
+            enabled=self.use_deterministic_choice.isChecked(),
+            seed=self.box_seed.value(),
+        )
+
+        validation_percent = self.validation_percent_choice.slider_value / 100
+
+        results_path_folder = Path(
+            self.results_path + f"/{model_config.name}_{utils.get_date_time()}"
+        )
+        Path(results_path_folder).mkdir(
+            parents=True, exist_ok=False
+        )  # avoid overwrite where possible
+
+        patch_size = [w.value() for w in self.patch_size_widgets]
+
+        logger.debug("Loading config...")
+        self.worker_config = config.TrainingWorkerConfig(
+            device=self.get_device(),
+            model_info=model_config,
+            weights_info=self.weights_config,
+            train_data_dict=self.data,
+            validation_percent=validation_percent,
+            max_epochs=self.epoch_choice.value(),
+            loss_function=self.get_loss(self.loss_choice.currentText()),
+            learning_rate=float(self.learning_rate_choice.currentText()),
+            scheduler_patience=self.scheduler_patience_choice.value(),
+            scheduler_factor=self.scheduler_factor_choice.slider_value,
+            validation_interval=self.val_interval_choice.value(),
+            batch_size=self.batch_choice.slider_value,
+            results_path_folder=str(results_path_folder),
+            sampling=self.patch_choice.isChecked(),
+            num_samples=self.sample_choice_slider.slider_value,
+            sample_size=patch_size,
+            do_augmentation=self.augment_choice.isChecked(),
+            deterministic_config=deterministic_config,
+        )  # TODO(cyril) continue to put params in config
+
+        return self.worker_config
+
    def on_start(self):
        """Catches started signal from worker"""

@@ -968,7 +990,7 @@ def on_yield(self, report: TrainingReport):
                layer = self._viewer.add_image(
                    report.images[i],
                    name=layer_name + str(i),
-                    colormap="twilight",
+                    colormap="viridis",
                )
                self.result_layers.append(layer)
            else:
@@ -979,13 +1001,13 @@
                new_layer = self._viewer.add_image(
                    report.images[i],
                    name=layer_name + str(i),
-                    colormap="twilight",
+                    colormap="viridis",
                )
                self.result_layers.append(new_layer)
                self.result_layers[i].data = report.images[i]
                self.result_layers[i].refresh()
        except Exception as e:
-            logger.error(e)
+            logger.exception(e)

        self.progress.setValue(
            100 * (report.epoch + 1) // self.worker_config.max_epochs
@@ -1027,7 +1049,7 @@ def _make_csv(self):
        size_column = range(1, self.worker_config.max_epochs + 1)

        if len(self.loss_values) == 0 or self.loss_values is None:
-            warnings.warn("No loss values to add to csv !")
+            logger.warning("No loss values to add to csv !")
            return

        self.df = pd.DataFrame(
@@ -1117,7 +1139,7 @@ def update_loss_plot(self, loss, metric):
        epoch = len(loss)
        if epoch < self.worker_config.validation_interval * 2:
            return
-        elif epoch == 
self.worker_config.validation_interval * 2: + if epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): self.canvas = FigureCanvas(Figure(figsize=(10, 1.5))) @@ -1153,7 +1175,7 @@ def update_loss_plot(self, loss, metric): ) self.plot_dock._close_btn = False except AttributeError as e: - logger.error(e) + logger.exception(e) logger.error( "Plot dock widget could not be added. Should occur in testing only" ) diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py index 7ed6c549..dd98bcd7 100644 --- a/napari_cellseg3d/code_plugins/plugin_review.py +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path import matplotlib.pyplot as plt @@ -20,7 +19,6 @@ from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager -warnings.formatwarning = utils.format_Warning logger = utils.LOGGER @@ -181,7 +179,7 @@ def check_image_data(self): raise ValueError("Review requires at least one image") if cfg.labels is not None and cfg.image.shape != cfg.labels.shape: - warnings.warn( + logger.warning( "Image and label dimensions do not match ! Please load matching images" ) @@ -237,7 +235,7 @@ def run_review(self): self._reset() previous_viewer.close() except ValueError as e: - warnings.warn( + logger.warning( f"An exception occurred : {e}. Please ensure you have entered all required parameters." ) @@ -401,7 +399,7 @@ def update_canvas_canvas(viewer, event): ) canvas.draw_idle() except Exception as e: - logger.error(e) + logger.exception(e) # Qt widget defined in docker.py dmg = Datamanager(parent=viewer) diff --git a/napari_cellseg3d/code_plugins/plugin_review_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py index c09c376f..f634d117 100644 --- a/napari_cellseg3d/code_plugins/plugin_review_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,10 +1,12 @@ -import warnings from datetime import datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING -import napari import pandas as pd +if TYPE_CHECKING: + import napari + # Qt from qtpy.QtWidgets import QVBoxLayout, QWidget @@ -16,7 +18,7 @@ GUI_MINIMUM_HEIGHT = 300 TIMER_FORMAT = "%H:%M:%S" - +logger = utils.LOGGER """ plugin_dock.py ==================================== @@ -261,7 +263,7 @@ def update_dm(self, slice_num): def button_func(self): # updates csv every time you press button... 
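
# Same pattern as the rest of this diff: UI guards now go through
# utils.LOGGER instead of warnings.warn. A minimal sketch of the check just
# below; ensure_2d is a hypothetical helper name used for illustration.
from napari_cellseg3d.utils import LOGGER as logger

def ensure_2d(ndisplay: int) -> bool:
    """Return True only when the viewer is in 2D slice mode."""
    if ndisplay != 2:
        logger.warning("Please switch back to 2D mode !")
        return False
    return True
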
if self.viewer.dims.ndisplay != 2: # TODO test if undefined behaviour or if okay - warnings.warn("Please switch back to 2D mode !") + logger.warning("Please switch back to 2D mode !") return self.update_time_csv() diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 5463a4ff..6e1a606a 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -1,4 +1,7 @@ -import napari +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import napari # Qt from qtpy.QtCore import qInstallMessageHandler @@ -13,6 +16,7 @@ ToInstanceUtils, ToSemanticUtils, ) +from napari_cellseg3d.code_plugins.plugin_crf import CRFWidget from napari_cellseg3d.code_plugins.plugin_crop import Cropping UTILITIES_WIDGETS = { @@ -22,6 +26,7 @@ "Convert to instance labels": ToInstanceUtils, "Convert to semantic labels": ToSemanticUtils, "Threshold": ThresholdUtils, + "CRF": CRFWidget, } @@ -30,7 +35,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() self._viewer = viewer - attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh"] + attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh", "crf"] self._create_utils_widgets(attr_names) # self.crop = Cropping(self._viewer) @@ -41,7 +46,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # self.small = RemoveSmallUtils(self._viewer) self.utils_choice = ui.DropdownMenu( - UTILITIES_WIDGETS.keys(), label="Utilities" + UTILITIES_WIDGETS.keys(), text_label="Utilities" ) self._build() @@ -54,8 +59,15 @@ def __init__(self, viewer: "napari.viewer.Viewer"): def _build(self): layout = QVBoxLayout() ui.add_widgets(layout, self.utils_widgets) - layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) - layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) + ui.GroupedWidget.create_single_widget_group( + "Utilities", + widget=self.utils_choice, + layout=layout, + alignment=ui.BOTT_AL, + ) + + # layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) + # layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) # layout.setSizeConstraint(QLayout.SetFixedSize) self.setLayout(layout) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 737b53aa..af42d779 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -1,5 +1,4 @@ import datetime -import warnings from dataclasses import dataclass from pathlib import Path from typing import List, Optional @@ -7,15 +6,14 @@ import napari import numpy as np -from napari_cellseg3d.code_models.model_instance_seg import InstanceMethod +from napari_cellseg3d.code_models.instance_segmentation import InstanceMethod # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet -from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.code_models.models import ( - model_TRAILMAP_MS as TRAILMAP_MS, -) -from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.code_models.models.model_SegResNet import SegResNet_ +from napari_cellseg3d.code_models.models.model_SwinUNetR import SwinUNETR_ +from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ +from napari_cellseg3d.code_models.models.model_VNet import VNet_ +from napari_cellseg3d.code_models.models.model_WNet import WNet_ from napari_cellseg3d.utils import LOGGER logger = LOGGER @@ -24,16 +22,16 @@ # 
TODO(cyril) add JSON load/save

MODEL_LIST = {
-    "SegResNet": SegResNet,
-    "VNet": VNet,
+    "SegResNet": SegResNet_,
+    "VNet": VNet_,
    # "TRAILMAP": TRAILMAP,
-    "TRAILMAP_MS": TRAILMAP_MS,
-    "SwinUNetR": SwinUNetR,
+    "TRAILMAP_MS": TRAILMAP_MS_,
+    "SwinUNetR": SwinUNETR_,
+    "WNet": WNet_,
    # "test" : DO NOT USE, reserved for testing
}

-
-WEIGHTS_DIR = str(
+PRETRAINED_WEIGHTS_DIR = str(
    Path(__file__).parent.resolve() / Path("code_models/models/pretrained")
)

@@ -69,30 +67,37 @@ class ReviewSession:

@dataclass
class ModelInfo:
-    """Dataclass recording model info :
-
-    - name (str): name of the model"""
+    """Dataclass recording model info
+    Args:
+        name (str): name of the model
+        model_input_size (Optional[List[int]]): input size of the model
+        num_classes (int): number of classes for the model
+    """

    name: str = next(iter(MODEL_LIST))
    model_input_size: Optional[List[int]] = None
+    num_classes: int = 2

    def get_model(self):
        try:
            return MODEL_LIST[self.name]
        except KeyError as e:
            msg = f"Model {self.name} is not defined"
-            warnings.warn(msg)
            logger.warning(msg)
-            raise KeyError(e)
+            raise KeyError from e

    @staticmethod
    def get_model_name_list():
-        logger.info("Model list :\n" + str(f"{name}\n" for name in MODEL_LIST))
+        logger.info("Model list :")
+        for model_name in MODEL_LIST:
+            logger.info(f" * {model_name}")
        return MODEL_LIST.keys()


@dataclass
class WeightsInfo:
-    path: str = WEIGHTS_DIR
+    path: str = PRETRAINED_WEIGHTS_DIR
    custom: bool = False
    use_pretrained: Optional[bool] = False

@@ -121,11 +126,33 @@ class InstanceSegConfig:

@dataclass
class PostProcessConfig:
+    """Class to record params for post processing
+
+    Args:
+        zoom (Zoom): zoom config
+        thresholding (Thresholding): thresholding config
+        instance (InstanceSegConfig): instance segmentation config
+    """
+
    zoom: Zoom = Zoom()
    thresholding: Thresholding = Thresholding()
    instance: InstanceSegConfig = InstanceSegConfig()


+@dataclass
+class CRFConfig:
+    """
+    Class to record params for CRF
+    """
+
+    sa: float = 10
+    sb: float = 5
+    sg: float = 1
+    w1: float = 10
+    w2: float = 5
+    n_iters: int = 5
+
+
################
# Inference configs
@@ -141,7 +168,15 @@ def is_enabled(self):

@dataclass
class InfererConfig:
-    """Class to record params for Inferer plugin"""
+    """Class to record params for Inferer plugin
+
+    Args:
+        model_info (ModelInfo): model info
+        show_results (bool): show results in napari
+        show_results_count (int): number of results to show
+        show_original (bool): show original image in napari
+        anisotropy_resolution (List[int]): anisotropy resolution
+    """

    model_info: ModelInfo = None
    show_results: bool = False
@@ -152,7 +187,21 @@

@dataclass
class InferenceWorkerConfig:
-    """Class to record configuration for Inference job"""
+    """Class to record configuration for Inference job
+
+    Args:
+        device (str): device to use for inference
+        model_info (ModelInfo): model info
+        weights_config (WeightsInfo): weights info
+        results_path (str): path to save results
+        filetype (str): filetype to save results
+        keep_on_cpu (bool): keep results on cpu
+        compute_stats (bool): compute stats
+        post_process_config (PostProcessConfig): post processing config
+        sliding_window_config (SlidingWindowConfig): sliding window config
+        images_filepaths (str): path to images to infer
+        layer (napari.layers.Layer): napari layer to infer on
+    """

    device: str = "cpu"
    model_info: ModelInfo = ModelInfo()
@@ -163,6 +212,8 @@
    compute_stats: bool = False
    post_process_config: PostProcessConfig = 
PostProcessConfig() sliding_window_config: SlidingWindowConfig = SlidingWindowConfig() + use_crf: bool = False + crf_config: CRFConfig = CRFConfig() images_filepaths: str = None layer: napari.layers.Layer = None @@ -199,6 +250,8 @@ class TrainingWorkerConfig: max_epochs: int = 5 loss_function: callable = None learning_rate: np.float64 = 1e-3 + scheduler_patience: int = 10 + scheduler_factor: float = 0.5 validation_interval: int = 2 batch_size: int = 1 results_path_folder: str = str(Path.home() / Path("cellseg3d/training")) @@ -207,3 +260,21 @@ class TrainingWorkerConfig: sample_size: List[int] = None do_augmentation: bool = True deterministic_config: DeterministicConfig = DeterministicConfig() + + +################ +# CRF config for WNet +################ + + +@dataclass +class WNetCRFConfig: + "Class to store parameters of WNet CRF post processing" + + # CRF + sa = 10 # 50 + sb = 10 + sg = 1 + w1 = 10 # 50 + w2 = 10 + n_iter = 5 diff --git a/napari_cellseg3d/dev_scripts/artefact_labeling.py b/napari_cellseg3d/dev_scripts/artefact_labeling.py index 3f95e1a8..93746eb6 100644 --- a/napari_cellseg3d/dev_scripts/artefact_labeling.py +++ b/napari_cellseg3d/dev_scripts/artefact_labeling.py @@ -1,4 +1,5 @@ -import os +import os # TODO(cyril): remove os +from pathlib import Path import napari import numpy as np @@ -6,11 +7,12 @@ from skimage.filters import threshold_otsu from tifffile import imread, imwrite -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + """ New code by Yves Paychere Creates labels of artifacts in an image based on existing labels of neurons @@ -76,7 +78,7 @@ def make_labels( Parameters ---------- image : str - image array + Path to image. path_labels_out : str Path of the output labelled image. threshold_size : int, optional @@ -95,7 +97,7 @@ def make_labels( Label image with nucleus labelled with 1 value per nucleus. """ - image = imread(image) + # image = imread(image) image = (image - np.min(image)) / (np.max(image) - np.min(image)) threshold_brightness = threshold_otsu(image) * threshold_factor @@ -105,7 +107,6 @@ def make_labels( image_contrasted = (image_contrasted - np.min(image_contrasted)) / ( np.max(image_contrasted) - np.min(image_contrasted) ) - image_contrasted = image_contrasted * augment_contrast_factor image_contrasted = np.where(image_contrasted > 1, 1, image_contrasted) labels = binary_watershed(image_contrasted, thres_small=threshold_size) @@ -141,7 +142,6 @@ def select_image_by_labels(image, labels, path_image_out, label_values): """ # image = imread(image) # labels = imread(labels) - image = np.where(np.isin(labels, label_values), image, 0) imwrite(path_image_out, image.astype(np.float32)) @@ -290,18 +290,13 @@ def select_artefacts_by_size(artefacts, min_size, is_labeled=False): ndarray Label image with artefacts labelled and small artefacts removed. 
""" - if not is_labeled: - # find all the connected components in the artefacts image - labels = ndimage.label(artefacts)[0] - else: - labels = artefacts + labels = ndimage.label(artefacts)[0] if not is_labeled else artefacts # remove the small components labels_i, counts = np.unique(labels, return_counts=True) labels_i = labels_i[counts > min_size] labels_i = labels_i[labels_i > 0] - artefacts = np.where(np.isin(labels, labels_i), labels, 0) - return artefacts + return np.where(np.isin(labels, labels_i), labels, 0) def create_artefact_labels( @@ -389,7 +384,7 @@ def create_artefact_labels_from_folder( path_labels.sort() path_images.sort() # create the output folder - os.makedirs(path + "/artefact_neurons", exist_ok=True) + Path().mkdir(path + "/artefact_neurons", exist_ok=True) # create the artefact labels for i in range(len(path_images)): print(path_labels[i]) diff --git a/napari_cellseg3d/dev_scripts/convert.py b/napari_cellseg3d/dev_scripts/convert.py deleted file mode 100644 index 641de627..00000000 --- a/napari_cellseg3d/dev_scripts/convert.py +++ /dev/null @@ -1,26 +0,0 @@ -import glob -import os - -import numpy as np -from tifffile import imread, imwrite - -# input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab" -# output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/pytorch-test3dunet/cropped_visual/train/lab_sem" - -input_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/labels" -output_seg_path = "C:/Users/Cyril/Desktop/Proj_bachelor/code/cellseg-annotator-test/napari_cellseg3d/models/dataset/lab_sem" - -filenames = [] -paths = [] -filetype = ".tif" -for filename in glob.glob(os.path.join(input_seg_path, "*" + filetype)): - paths.append(filename) - filenames.append(os.path.basename(filename)) - # print(os.path.basename(filename)) -for file in paths: - image = imread(file) - - image[image >= 1] = 1 - image = image.astype(np.uint16) - - imwrite(output_seg_path + "/" + os.path.basename(file), image) diff --git a/napari_cellseg3d/dev_scripts/correct_labels.py b/napari_cellseg3d/dev_scripts/correct_labels.py index 168990e1..f413812d 100644 --- a/napari_cellseg3d/dev_scripts/correct_labels.py +++ b/napari_cellseg3d/dev_scripts/correct_labels.py @@ -12,10 +12,12 @@ from tqdm import tqdm import napari_cellseg3d.dev_scripts.artefact_labeling as make_artefact_labels -from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.instance_segmentation import binary_watershed # import sys # sys.path.append(str(Path(__file__) / "../../")) + + """ New code by Yves Paychère Fixes labels and allows to auto-detect artifacts and neurons based on a simple intenstiy threshold @@ -361,9 +363,9 @@ def relabel_non_unique_i_folder(folder_path, end_of_new_name="relabeled"): ) -# if __name__ == "__main__": -# im_path = Path("C:/Users/Cyril/Desktop/test/instance_test") -# image_path = str(im_path / "image.tif") -# gt_labels_path = str(im_path / "labels.tif") -# -# relabel(image_path, gt_labels_path, check_for_unicity=True, go_fast=False) +if __name__ == "__main__": + im_path = Path("C:/Users/Cyril/Desktop/Proj_bachelor/data/somatomotor") + + image_path = str(im_path / "volumes/c1images.tif") + gt_labels_path = str(im_path / "labels/c1labels.tif") + relabel(image_path, gt_labels_path, check_for_unicity=False, go_fast=False) diff --git a/napari_cellseg3d/dev_scripts/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py deleted file mode 100644 
index cdd02256..00000000 --- a/napari_cellseg3d/dev_scripts/drafts.py +++ /dev/null @@ -1,15 +0,0 @@ -import napari -import numpy as np -from magicgui import magicgui -from napari.types import ImageData, LabelsData - - -@magicgui(call_button="Run Threshold") -def threshold(image: ImageData, threshold: int = 75) -> LabelsData: - """Threshold an image and return a mask.""" - return (image > threshold).astype(int) - - -viewer = napari.view_image(np.random.randint(0, 100, (64, 64))) -viewer.window.add_dock_widget(threshold) -threshold() diff --git a/napari_cellseg3d/dev_scripts/evaluate_labels.py b/napari_cellseg3d/dev_scripts/evaluate_labels.py index bd2f0768..2830f4e7 100644 --- a/napari_cellseg3d/dev_scripts/evaluate_labels.py +++ b/napari_cellseg3d/dev_scripts/evaluate_labels.py @@ -127,7 +127,7 @@ def evaluate_model_performance( ) if visualize: - viewer = napari.Viewer() + viewer = napari.Viewer(ndisplay=3) viewer.add_labels(labels, name="ground truth") viewer.add_labels(model_labels, name="model's labels") found_model = np.where( @@ -474,193 +474,6 @@ def save_as_csv(results, path): # # return label_report_list, new_labels, map_fused_neurons -####################### -# Slower version that was used for debugging -####################### - -# from collections import Counter -# from dataclasses import dataclass -# from typing import Dict -# @dataclass -# class LabelInfo: -# gt_index: int -# model_labels_id_and_status: Dict = None # for each model label id present on gt_index in gt labels, contains status (correct/wrong) -# best_model_label_coverage: float = ( -# 0.0 # ratio of pixels of the gt label correctly labelled -# ) -# overall_gt_label_coverage: float = 0.0 # true positive ration of the model -# -# def get_correct_ratio(self): -# for model_label, status in self.model_labels_id_and_status.items(): -# if status == "correct": -# return self.best_model_label_coverage -# else: -# return None - - -# def eval_model(gt_labels, model_labels, print_report=False): -# -# report_list, new_labels, fused_labels = create_label_report( -# gt_labels, model_labels -# ) -# per_label_perfs = [] -# for report in report_list: -# if print_report: -# log.info( -# f"Label {report.gt_index} : {report.model_labels_id_and_status}" -# ) -# log.info( -# f"Best model label coverage : {report.best_model_label_coverage}" -# ) -# log.info( -# f"Overall gt label coverage : {report.overall_gt_label_coverage}" -# ) -# -# perf = report.get_correct_ratio() -# if perf is not None: -# per_label_perfs.append(perf) -# -# per_label_perfs = np.array(per_label_perfs) -# return per_label_perfs.mean(), new_labels, fused_labels - - -# def create_label_report(gt_labels, model_labels): -# """Map the model's labels to the neurons labels. -# Parameters -# ---------- -# gt_labels : ndarray -# Label image with neurons labelled as mulitple values. -# model_labels : ndarray -# Label image from the model labelled as mulitple values. 
-# Returns -# ------- -# map_labels_existing: numpy array -# The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label correctly labelled -# map_fused_neurons: numpy array -# The neurones are considered fused if they are labelled by the same model's label, in this case we will return The label value of the model and the label value of the neurone associated, the ratio of the pixels of the true label correctly labelled, the ratio of the pixels of the model's label that are in one of the fused neurones -# new_labels: list -# The labels of the model that are not labelled in the neurons, the ratio of the pixels of the model's label that are an artefact -# """ -# -# map_labels_existing = [] -# map_fused_neurons = {} -# "background_labels contains all model labels where gt_labels is 0 and model_labels is not 0" -# background_labels = model_labels[np.where((gt_labels == 0))] -# "new_labels contains all labels in model_labels for which more than PERCENT_CORRECT% of the pixels are not labelled in gt_labels" -# new_labels = [] -# for lab in np.unique(background_labels): -# if lab == 0: -# continue -# gt_background_size_at_lab = ( -# gt_labels[np.where((model_labels == lab) & (gt_labels == 0))] -# .flatten() -# .shape[0] -# ) -# gt_lab_size = ( -# gt_labels[np.where(model_labels == lab)].flatten().shape[0] -# ) -# if gt_background_size_at_lab / gt_lab_size > PERCENT_CORRECT: -# new_labels.append(lab) -# -# label_report_list = [] -# # label_report = {} # contains a dict saying which labels are correct or wrong for each gt label -# # model_label_values = {} # contains the model labels value assigned to each unique gt label -# not_found_id = 0 -# -# for i in tqdm(np.unique(gt_labels)): -# if i == 0: -# continue -# -# gt_label = gt_labels[np.where(gt_labels == i)] # get a single gt label -# -# model_lab_on_gt = model_labels[ -# np.where(((gt_labels == i) & (model_labels != 0))) -# ] # all models labels on single gt_label -# info = LabelInfo(i) -# -# info.model_labels_id_and_status = { -# label_id: "" for label_id in np.unique(model_lab_on_gt) -# } -# -# if model_lab_on_gt.shape[0] == 0: -# info.model_labels_id_and_status[ -# f"not_found_{not_found_id}" -# ] = "not found" -# not_found_id += 1 -# label_report_list.append(info) -# continue -# -# log.debug(f"model_lab_on_gt : {np.unique(model_lab_on_gt)}") -# -# # create LabelInfo object and init model_labels_id_and_status with all unique model labels on gt_label -# log.debug( -# f"info.model_labels_id_and_status : {info.model_labels_id_and_status}" -# ) -# -# ratio = [] -# for model_lab_id in info.model_labels_id_and_status.keys(): -# size_model_label = ( -# model_lab_on_gt[np.where(model_lab_on_gt == model_lab_id)] -# .flatten() -# .shape[0] -# ) -# size_gt_label = gt_label.flatten().shape[0] -# -# log.debug(f"size_model_label : {size_model_label}") -# log.debug(f"size_gt_label : {size_gt_label}") -# -# ratio.append(size_model_label / size_gt_label) -# -# # log.debug(ratio) -# ratio_model_lab_for_given_gt_lab = np.array(ratio) -# info.best_model_label_coverage = ratio_model_lab_for_given_gt_lab.max() -# -# best_model_lab_id = model_lab_on_gt[ -# np.argmax(ratio_model_lab_for_given_gt_lab) -# ] -# log.debug(f"best_model_lab_id : {best_model_lab_id}") -# -# info.overall_gt_label_coverage = ( -# ratio_model_lab_for_given_gt_lab.sum() -# ) # the ratio of the pixels of the true label correctly labelled -# -# if 
info.best_model_label_coverage > PERCENT_CORRECT: -# info.model_labels_id_and_status[best_model_lab_id] = "correct" -# # info.model_labels_id_and_size[best_model_lab_id] = model_labels[np.where(model_labels == best_model_lab_id)].flatten().shape[0] -# else: -# info.model_labels_id_and_status[best_model_lab_id] = "wrong" -# for model_lab_id in np.unique(model_lab_on_gt): -# if model_lab_id != best_model_lab_id: -# log.debug(model_lab_id, "is wrong") -# info.model_labels_id_and_status[model_lab_id] = "wrong" -# -# label_report_list.append(info) -# -# correct_labels_id = [] -# for report in label_report_list: -# for i_lab in report.model_labels_id_and_status.keys(): -# if report.model_labels_id_and_status[i_lab] == "correct": -# correct_labels_id.append(i_lab) -# """Find all labels in label_report_list that are correct more than once""" -# duplicated_labels = [ -# item for item, count in Counter(correct_labels_id).items() if count > 1 -# ] -# "Sum up the size of all duplicated labels" -# for i in duplicated_labels: -# for report in label_report_list: -# if ( -# i in report.model_labels_id_and_status.keys() -# and report.model_labels_id_and_status[i] == "correct" -# ): -# size = ( -# model_labels[np.where(model_labels == i)] -# .flatten() -# .shape[0] -# ) -# map_fused_neurons[i] = size -# -# return label_report_list, new_labels, map_fused_neurons - # if __name__ == "__main__": # """ # # Example of how to use the functions in this module. diff --git a/napari_cellseg3d/dev_scripts/thread_test.py b/napari_cellseg3d/dev_scripts/thread_test.py index 20668125..a48f6db0 100644 --- a/napari_cellseg3d/dev_scripts/thread_test.py +++ b/napari_cellseg3d/dev_scripts/thread_test.py @@ -1,8 +1,8 @@ import time import napari -import numpy as np from napari.qt.threading import thread_worker +from numpy.random import PCG64, Generator from qtpy.QtWidgets import ( QGridLayout, QLabel, @@ -13,6 +13,8 @@ QWidget, ) +rand_gen = Generator(PCG64(12345)) + @thread_worker def two_way_communication_with_args(start, end): @@ -129,7 +131,7 @@ def on_finish(): if __name__ == "__main__": - viewer = napari.view_image(np.random.rand(512, 512)) + viewer = napari.view_image(rand_gen.random((512, 512))) w = create_connected_widget(viewer) viewer.window.add_dock_widget(w) diff --git a/napari_cellseg3d/dev_scripts/view_brain.py b/napari_cellseg3d/dev_scripts/view_brain.py deleted file mode 100644 index 145d4e45..00000000 --- a/napari_cellseg3d/dev_scripts/view_brain.py +++ /dev/null @@ -1,8 +0,0 @@ -import napari -from tifffile import imread - -y = imread("/Users/maximevidal/Documents/3drawdata/wholebrain.tif") - -with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(y, contrast_limits=[0, 2000], multiscale=False) diff --git a/napari_cellseg3d/dev_scripts/view_sample.py b/napari_cellseg3d/dev_scripts/view_sample.py deleted file mode 100644 index 8e87f85c..00000000 --- a/napari_cellseg3d/dev_scripts/view_sample.py +++ /dev/null @@ -1,29 +0,0 @@ -import napari -from tifffile import imread - -# Visual -x = imread( - "/Users/maximevidal/Documents/trailmap/data/no-edge-validation/visual-original/volumes/images.tif" -) -y_semantic = imread( - "/Users/maximevidal/Documents/trailmap/data/testing/seg-visual1-single/image.tif" -) -y_instance = imread( - "/Users/maximevidal/Documents/trailmap/data/instance-testing/test-visual-5.tiff" -) -y_true = imread( - "/Users/maximevidal/Documents/3drawdata/visual/labels/labels.tif" -) - -# SM -# x =
imread("/Users/maximevidal/Documents/trailmap/data/no-edge-validation/validation-original/volumes/c5images.tif") -# y = imread("/Users/maximevidal/Documents/trailmap/data/instance-testing/test1.tiff") -# y_true = imread("/Users/maximevidal/Documents/3drawdata/somatomotor/labels/c5labels.tif") - -with napari.gui_qt(): - viewer = napari.view_image( - x, colormap="inferno", contrast_limits=[200, 1000] - ) - viewer.add_image(y_semantic, name="semantic_predictions", opacity=0.5) - viewer.add_labels(y_instance, name="instance_predictions", seed=0.6) - viewer.add_labels(y_true, name="truth", seed=0.6) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py deleted file mode 100644 index 6cdb9c43..00000000 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ /dev/null @@ -1,234 +0,0 @@ -import collections -import os - -import torch - -from napari_cellseg3d.code_models.models import get_net -from napari_cellseg3d.code_models.models.unet.model import UNet3D - -# not sure this actually works when put here - - -def weight_translate(k, w): - k = key_translate(k) - if k.endswith(".weight"): - if w.dim() == 2: - w = w.t() - elif w.dim() == 1: - pass - elif w.dim() == 4: - w = w.permute(3, 2, 0, 1) - else: - assert w.dim() == 5 - w = w.permute(4, 3, 0, 1, 2) - return w - - -def key_translate(k): - k = ( - k.replace( - "conv3d/kernel:0", - "encoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization/gamma:0", - "encoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization/beta:0", - "encoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_1/kernel:0", - "encoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_1/gamma:0", - "encoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_1/beta:0", - "encoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_2/kernel:0", - "encoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_2/gamma:0", - "encoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_2/beta:0", - "encoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_3/kernel:0", - "encoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_3/gamma:0", - "encoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_3/beta:0", - "encoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_4/kernel:0", - "encoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_4/gamma:0", - "encoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_4/beta:0", - "encoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_5/kernel:0", - "encoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_5/gamma:0", - "encoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_5/beta:0", - "encoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_6/kernel:0", - "encoders.3.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_6/gamma:0", - "encoders.3.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_6/beta:0", - "encoders.3.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_7/kernel:0", 
- "encoders.3.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_7/gamma:0", - "encoders.3.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_7/beta:0", - "encoders.3.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_8/kernel:0", - "decoders.0.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_8/gamma:0", - "decoders.0.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_8/beta:0", - "decoders.0.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_9/kernel:0", - "decoders.0.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_9/gamma:0", - "decoders.0.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_9/beta:0", - "decoders.0.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_10/kernel:0", - "decoders.1.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_10/gamma:0", - "decoders.1.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_10/beta:0", - "decoders.1.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_11/kernel:0", - "decoders.1.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_11/gamma:0", - "decoders.1.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_11/beta:0", - "decoders.1.basic_module.SingleConv2.batchnorm.bias", - ) - .replace( - "conv3d_12/kernel:0", - "decoders.2.basic_module.SingleConv1.conv.weight", - ) - .replace( - "batch_normalization_12/gamma:0", - "decoders.2.basic_module.SingleConv1.batchnorm.weight", - ) - .replace( - "batch_normalization_12/beta:0", - "decoders.2.basic_module.SingleConv1.batchnorm.bias", - ) - .replace( - "conv3d_13/kernel:0", - "decoders.2.basic_module.SingleConv2.conv.weight", - ) - .replace( - "batch_normalization_13/gamma:0", - "decoders.2.basic_module.SingleConv2.batchnorm.weight", - ) - .replace( - "batch_normalization_13/beta:0", - "decoders.2.basic_module.SingleConv2.batchnorm.bias", - ) - .replace("conv3d_14/kernel:0", "final_conv.weight") - .replace("conv3d_14/bias:0", "final_conv.bias") - ) - return k - - -model = get_net() -base_path = os.path.abspath(__file__ + "/..") -weights_path = base_path + "/data/model-weights/trailmap_model.hdf5" -model.load_weights(weights_path) - -for i, l in enumerate(model.layers): - print(i, l) - print( - "L{}: {}".format( - i, ", ".join(str(w.shape) for w in model.layers[i].weights) - ) - ) - -weights_pt = collections.OrderedDict( - [(w.name, torch.from_numpy(w.numpy())) for w in model.trainable_variables] -) -torch.save(weights_pt, base_path + "/data/model-weights/trailmaptorch.pt") -torch_weights = torch.load(base_path + "/data/model-weights/trailmaptorch.pt") -param_dict = { - key_translate(k): weight_translate(k, v) for k, v in torch_weights.items() -} - -trailmap_model = UNet3D(1, 1) -torchparam = trailmap_model.state_dict() -for k, v in torchparam.items(): - print("{:20s} {}".format(k, v.shape)) - -trailmap_model.load_state_dict(param_dict, strict=False) -torch.save( - trailmap_model.state_dict(), - base_path + "/data/model-weights/trailmaptorchpretrained.pt", -) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 276f9214..3c43c81d 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,5 +1,5 @@ +import contextlib import threading -import warnings from functools import partial from typing 
import List, Optional @@ -8,9 +8,12 @@ # Qt # from qtpy.QtCore import QtWarningMsg from qtpy import QtCore + +# from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QObject, Qt, QUrl from qtpy.QtGui import QCursor, QDesktopServices, QTextCursor from qtpy.QtWidgets import ( + QAbstractSpinBox, QCheckBox, QComboBox, QDoubleSpinBox, @@ -57,6 +60,8 @@ """Alias for Qt.AlignmentFlag.AlignAbsolute, to use in addWidget""" BOTT_AL = Qt.AlignmentFlag.AlignBottom """Alias for Qt.AlignmentFlag.AlignBottom, to use in addWidget""" +TOP_AL = Qt.AlignmentFlag.AlignTop +"""Alias for Qt.AlignmentFlag.AlignTop, to use in addWidget""" ############### # colors dark_red = "#72071d" # crimson red @@ -65,6 +70,9 @@ napari_param_grey = "#414851" # napari parameters menu color (lighter gray) napari_param_darkgrey = "#202228" # napari default LineEdit color ############### +# dimensions for utils ScrollArea +UTILS_MAX_WIDTH = 300 +UTILS_MAX_HEIGHT = 500 logger = utils.LOGGER @@ -99,12 +107,12 @@ def __call__(cls, *args, **kwargs): ################## -def handle_adjust_errors(widget, type, context, msg: str): +def handle_adjust_errors(widget, warning_type, context, msg: str): """Qt message handler that attempts to react to errors when setting the window size and resizes the main window""" pass # head = msg.split(": ")[0] - # if type == QtWarningMsg and head == "QWindowsWindow::setGeometry": + # if warning_type == QtWarningMsg and head == "QWindowsWindow::setGeometry": # logger.warning( # f"Qt resize error : {msg}\nhas been handled by attempting to resize the window" # ) @@ -285,10 +293,26 @@ def print_and_log(self, text, printing=True): self.lock.release() def warn(self, warning): - """Show warnings.warn from another thread""" + """Show logger.warning from another thread""" self.lock.acquire() try: - warnings.warn(warning) + logger.warning(warning) finally: self.lock.release() + + def error(self, error, msg=None): + """Show exception and message from another thread""" + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Exception caught in another thread : {error}", + printing=False, + ) + raise error finally: self.lock.release() @@ -311,8 +335,7 @@ def toggle_visibility(checkbox, widget): def add_label(widget, label, label_before=True, horizontal=True): if label_before: return combine_blocks(widget, label, horizontal=horizontal) - else: - return combine_blocks(label, widget, horizontal=horizontal) + return combine_blocks(label, widget, horizontal=horizontal) class ContainerWidget(QWidget): @@ -389,21 +412,21 @@ def __init__( self, entries: Optional[list] = None, parent: Optional[QWidget] = None, - label: Optional[str] = None, + text_label: Optional[str] = None, fixed: Optional[bool] = True, ): """Args: entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None - label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + text_label (str) : if not None, creates a QLabel with the contents of 'text_label', and returns the label as well fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True.
""" super().__init__(parent) self.label = None if entries is not None: self.addItems(entries) - if label is not None: - self.label = QLabel(label) + if text_label is not None: + self.label = QLabel(text_label) if fixed: self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) @@ -448,15 +471,21 @@ def __init__( ): super().__init__(orientation, parent) + if upper <= lower: + raise ValueError( + "The minimum value cannot be below the maximum one" + ) + self.setMaximum(upper) self.setMinimum(lower) self.setSingleStep(step) self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.text_label = None + self.label = None self.container = ContainerWidget( - # parent=self.parent + # parent=self.parent, + b=0, ) self._divide_factor = divide_factor @@ -479,7 +508,7 @@ def __init__( ) if text_label is not None: - self.text_label = make_label(text_label, parent=self) + self.label = make_label(text_label, parent=self) if default < lower: self._warn_outside_bounds(default) @@ -498,14 +527,14 @@ def __init__( def set_visibility(self, visible: bool): self.container.setVisible(visible) self.setVisible(visible) - self.text_label.setVisible(visible) + self.label.setVisible(visible) def _build_container(self): - if self.text_label is not None: + if self.label is not None: add_widgets( self.container.layout, [ - self.text_label, + self.label, combine_blocks(self._value_label, self, b=0), ], ) @@ -516,29 +545,35 @@ def _build_container(self): ) def _warn_outside_bounds(self, default): - warnings.warn( + logger.warning( f"Default value {default} was outside of the ({self.minimum()}:{self.maximum()}) range" ) def _update_slider(self): """Update slider when value is changed""" - if self._value_label.text() == "": - return + try: + if self._value_label.text() == "": + return - value = float(self._value_label.text()) * self._divide_factor + value = float(self._value_label.text()) * self._divide_factor - if value < self.minimum(): - self.slider_value = self.minimum() - return - if value > self.maximum(): - self.slider_value = self.maximum() - return + if value < self.minimum(): + self.slider_value = self.minimum() + return + if value > self.maximum(): + self.slider_value = self.maximum() + return - self.slider_value = value + self.slider_value = value + except Exception as e: + logger.error(e) def _update_value_label(self): """Update label, to connect to when slider is dragged""" - self._value_label.setText(str(self.value_text)) + try: + self._value_label.setText(str(self.value_text)) + except Exception as e: + logger.error(e) @property def tooltips(self): @@ -549,8 +584,8 @@ def tooltips(self, tooltip: str): self.setToolTip(tooltip) self._value_label.setToolTip(tooltip) - if self.text_label is not None: - self.text_label.setToolTip(tooltip) + if self.label is not None: + self.label.setToolTip(tooltip) @property def slider_value(self): @@ -561,7 +596,7 @@ def slider_value(self): try: return self.value() / self._divide_factor except ZeroDivisionError as e: - raise ZeroDivisionError( + raise ZeroDivisionError from ( f"Divide factor cannot be 0 for Slider : {e}" ) @@ -574,16 +609,21 @@ def value_text(self): def slider_value(self, value: int): """Set a value (int) divided by self._divide_factor""" if value < self.minimum() or value > self.maximum(): - raise ValueError( - f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + logger.error( + ValueError( + f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + ) ) - self.setValue(int(value)) 
+ try: + self.setValue(int(value)) - divided = value / self._divide_factor - if self._divide_factor == 1.0: - divided = int(divided) - self._value_label.setText(str(divided)) + divided = value / self._divide_factor + if self._divide_factor == 1.0: + divided = int(divided) + self._value_label.setText(str(divided)) + except Exception as e: + logger.error(e) class AnisotropyWidgets(QWidget): @@ -628,7 +668,7 @@ def __init__( w.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) self.box_widgets_lbl = [ - make_label("Resolution in " + axis + " (microns) :", parent=parent) + make_label("Pixel size in " + axis + " (microns) :", parent=parent) for axis in "xyz" ] @@ -696,9 +736,8 @@ def anisotropy_zoom_factor(aniso_res): """ - base = min(aniso_res) - zoom_factors = [base / res for res in aniso_res] - return zoom_factors + base = max(aniso_res) + return [res / base for res in aniso_res] def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" @@ -720,15 +759,23 @@ def __init__( self.image = None self.layer_type = layer_type - self.layer_list = DropdownMenu(parent=self, label=name, fixed=False) + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) + self.layer_description = make_label("Shape:", parent=self) + self.layer_description.setVisible(False) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? self._viewer.layers.events.inserted.connect(partial(self._add_layer)) self._viewer.layers.events.removed.connect(partial(self._remove_layer)) self.layer_list.currentIndexChanged.connect(self._update_tooltip) + self.layer_list.currentTextChanged.connect(self._update_description) - add_widgets(self.layout, [self.layer_list.label, self.layer_list]) + add_widgets( + self.layout, + [self.layer_list.label, self.layer_list, self.layer_description], + ) self._check_for_layers() def _check_for_layers(self): @@ -739,6 +786,14 @@ def _check_for_layers(self): def _update_tooltip(self): self.layer_list.setToolTip(self.layer_list.currentText()) + def _update_description(self): + if self.layer_list.currentText() != "": + self.layer_description.setVisible(True) + shape_desc = f"Shape : {self.layer_data().shape}" + self.layer_description.setText(shape_desc) + else: + self.layer_description.setVisible(False) + def _add_layer(self, event): inserted_layer = event.value @@ -756,23 +811,26 @@ def _remove_layer(self, event): index = self.layer_list.findText(removed_layer.name) self.layer_list.removeItem(index) - def set_layer_type(self, type): # no @property due to Qt constraint - self.layer_type = type + def set_layer_type(self, layer_type): # no @property due to Qt constraint + self.layer_type = layer_type [self.layer_list.removeItem(i) for i in range(self.layer_list.count())] self._check_for_layers() def layer(self): - return self._viewer.layers[self.layer_name()] + try: + return self._viewer.layers[self.layer_name()] + except ValueError: + return None def layer_name(self): return self.layer_list.currentText() def layer_data(self): if self.layer_list.count() < 1: - warnings.warn("Please select a valid layer !") + logger.warning("Please select a valid layer !") return None - return self._viewer.layers[self.layer_name()].data + return self.layer().data class FilePathWidget(QWidget): # TODO include load as folder @@ -814,6 +872,9 @@ def __init__( self.build() self.check_ready() + if self._required: + self._text_field.textChanged.connect(self.check_ready) + def build(self): """Builds the layout of the widget""" add_widgets( @@ 
-855,9 +916,8 @@ def check_ready(self): self.update_field_color("indianred") self.text_field.setToolTip("Mandatory field !") return False - else: - self.update_field_color(f"{napari_param_darkgrey}") - return True + self.update_field_color(f"{napari_param_darkgrey}") + return True @property def required(self): @@ -869,10 +929,9 @@ def required(self, is_required): if is_required: self.text_field.textChanged.connect(self.check_ready) else: - try: + with contextlib.suppress(TypeError): self.text_field.textChanged.disconnect(self.check_ready) - except TypeError: - return + self.check_ready() self._required = is_required @@ -959,22 +1018,22 @@ def make_scrollable( def set_spinbox( box, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, fixed: Optional[bool] = True, ): """Args: box : QSpinBox or QDoubleSpinBox - min : minimum value, defaults to 0 - max : maximum value, defaults to 10 + min_value : minimum value, defaults to 0 + max_value : maximum value, defaults to 10 default : default value, defaults to 0 step : step value, defaults to 1 fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" - box.setMinimum(min) - box.setMaximum(max) + box.setMinimum(min_value) + box.setMaximum(max_value) box.setSingleStep(step) box.setValue(default) @@ -985,8 +1044,8 @@ def set_spinbox( def make_n_spinboxes( class_, n: int = 2, - min=0, - max=10, + min_value=0, + max_value=10, default=0, step=1, parent: Optional[QWidget] = None, @@ -997,8 +1056,8 @@ def make_n_spinboxes( Args: class_ : QSpinBox or QDoubleSpinbox n (int): number of increment counters to create - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 + min_value (Optional[int]): minimum value, defaults to 0 + max_value (Optional[int]): maximum value, defaults to 10 default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None @@ -1009,7 +1068,7 @@ def make_n_spinboxes( boxes = [] for _i in range(n): - box = class_(min, max, default, step, parent, fixed) + box = class_(min_value, max_value, default, step, parent, fixed) boxes.append(box) return boxes @@ -1025,7 +1084,7 @@ def __init__( step: Optional[float] = 1.0, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[float]): minimum value, defaults to 0 @@ -1034,7 +1093,7 @@ def __init__( step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed - label (Optional[str]): if provided, creates a label with the chosen title to use with the counter + text_label (Optional[str]): if provided, creates a label with the chosen title to use with the counter """ super().__init__(parent) @@ -1042,15 +1101,16 @@ def __init__( self.layout = None - if label is not None: - self.label = make_label(name=label) - self.valueChanged.connect(self._update_step) + if text_label is not None: + self.label = make_label(name=text_label) + # self.valueChanged.connect(self._update_step) + self.setStepType(QAbstractSpinBox.StepType.AdaptiveDecimalStepType) - def _update_step(self): # FIXME check divide_factor - if self.value() < 0.9: - self.setSingleStep(0.01) - else: - self.setSingleStep(0.1) + # def _update_step(self): + # if self.value() <= 1: + # self.setSingleStep(0.1) + # else: + # self.setSingleStep(1) @property def tooltips(self): @@ -1103,7 +1163,7 @@ def 
__init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, - label: Optional[str] = None, + text_label: Optional[str] = None, ): """Args: lower (Optional[int]): minimum value, defaults to 0 @@ -1119,8 +1179,8 @@ def __init__( self.label = None self.container = None - if label is not None: - self.label = make_label(name=label) + if text_label is not None: + self.label = make_label(name=text_label) @property def tooltips(self): @@ -1166,8 +1226,8 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, - possible_paths: list = [], - filetype: str = "Image file (*.tif *.tiff)", + possible_paths: list = (), + file_extension: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -1176,29 +1236,27 @@ def open_file_dialog( possible_paths (str): Paths that may have been chosen before, can be a string or an array of strings containing the paths load_as_folder (bool): Whether to open a folder or a single file. If True, will allow opening folder as a single file (2D stack interpreted as 3D) - filetype (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` + file_extension (str): The description and file extension to load (format : ``"Description (*.example1 *.example2)"``). Default ``"Image file (*.tif *.tiff)"`` """ default_path = utils.parse_default_path(possible_paths) - f_name = QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype + return QFileDialog.getOpenFileName( + widget, "Choose file", default_path, file_extension ) - return f_name def open_folder_dialog( widget, - possible_paths: list = [], + possible_paths: list = (), ): default_path = utils.parse_default_path(possible_paths) logger.info(f"Default : {default_path}") - filenames = QFileDialog.getExistingDirectory( - widget, "Open directory", default_path + return QFileDialog.getExistingDirectory( + widget, "Open directory", default_path # + "/.." 
) - return filenames def make_label(name, parent=None): # TODO update to child class @@ -1215,12 +1273,11 @@ def make_label(name, parent=None): # TODO update to child class label = QLabel(name, parent) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label else: label = QLabel(name) if SHOW_LABELS_DEBUG_TOOLTIP: label.setToolTip(f"{label}") - return label + return label def make_group(title, l=7, t=20, r=7, b=11, parent=None): @@ -1258,12 +1315,20 @@ def set_layout(self): @classmethod def create_single_widget_group( - cls, title, widget, layout, l=7, t=20, r=7, b=11 + cls, + title, + widget, + layout, + l=7, + t=20, + r=7, + b=11, + alignment=LEFT_AL, ): group = cls(title, l, t, r, b) group.layout.addWidget(widget) group.setLayout(group.layout) - layout.addWidget(group) + layout.addWidget(group, alignment=alignment) def add_widgets(layout, widgets, alignment=LEFT_AL): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index a52c3de9..663872c4 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,18 +1,23 @@ import logging -import warnings from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Union +import napari import numpy as np +from monai.transforms import Zoom from skimage import io from skimage.filters import gaussian -from tifffile import imread as tfl_imread +from tifffile import imread, imwrite + +if TYPE_CHECKING: + import torch LOGGER = logging.getLogger(__name__) ############### # Global logging level setting -# LOGGER.setLevel(logging.DEBUG) -LOGGER.setLevel(logging.INFO) +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) ############### """ utils.py @@ -21,6 +26,76 @@ """ +#################### +# viewer utils +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False, parents=True) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + LOGGER.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + LOGGER.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + LOGGER.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + LOGGER.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + LOGGER.warning( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + +#################### + + class Singleton(type): """ Singleton class that can only be 
instantiated once at a time, @@ -36,6 +111,18 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] +# class TiffFileReader(ImageReader): +# def __init__(self): +# super().__init__() +# +# def verify_suffix(self, filename): +# if filename == "tif": +# return True +# def read(self, data, **kwargs): +# return imread(data) +# +# def get_data(self, data): +# return data, {} def normalize_x(image): """Normalizes the values of an image array to be between [-1;1] rather than [0;255] @@ -45,8 +132,11 @@ Returns: array: normalized value for the image """ - image = image / 127.5 - 1 - return image + return image / 127.5 - 1 + + +def mkdir_from_str(path: str, exist_ok=True, parents=True): + Path(path).resolve().mkdir(exist_ok=exist_ok, parents=parents) def normalize_y(image): @@ -58,8 +148,7 @@ Returns: array: normalized value for the image """ - image = image / 255 - return image + return image / 255 def sphericity_volume_area(volume, surface_area): @@ -113,15 +202,45 @@ def dice_coeff(y_true, y_pred): y_true_f = y_true.flatten() y_pred_f = y_pred.flatten() intersection = np.sum(y_true_f * y_pred_f) - score = (2.0 * intersection + smooth) / ( + return (2.0 * intersection + smooth) / ( np.sum(y_true_f) + np.sum(y_pred_f) + smooth ) - return score -def resize(image, zoom_factors): - from monai.transforms import Zoom +def correct_rotation(image): + """Rotates axes 0 and 2 of the [DHW] section of an image array""" + extra_dims = len(image.shape) - 3 + return np.swapaxes(image, 0 + extra_dims, 2 + extra_dims) + +def normalize_max(image): + """Normalizes an image using the max and min value""" + shape = image.shape + image = image.flatten() + image = (image - image.min()) / (image.max() - image.min()) + image = image.reshape(shape) + return image + + +def remap_image( + image: Union["np.ndarray", "torch.Tensor"], + new_max=100, + new_min=0, + prev_max=None, + prev_min=None, +): + """Normalizes a numpy array or Tensor using the max and min value""" + shape = image.shape + image = image.flatten() + im_max = prev_max if prev_max is not None else image.max() + im_min = prev_min if prev_min is not None else image.min() + image = (image - im_min) / (im_max - im_min) + image = image * (new_max - new_min) + new_min + image = image.reshape(shape) + return image + + +def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, keep_size=False, @@ -186,10 +305,11 @@ def time_difference(time_start, time_finish, as_string=True): minutes = f"{int(minutes[0])}".zfill(2) seconds = f"{int(seconds[0])}".zfill(2) - if as_string: - return f"{hours}:{minutes}:{seconds}" - else: - return [hours, minutes, seconds] + return ( + f"{hours}:{minutes}:{seconds}" + if as_string + else [hours, minutes, seconds] + ) def get_padding_dim(image_shape, anisotropy_factor=None): @@ -223,15 +343,15 @@ size = int(size / anisotropy_factor[i]) while pad < size: # if size - pad < 30: - # warnings.warn( + # LOGGER.warning( # f"Your value is close to a lower power of two; you might want to choose slightly smaller" # f" sizes and/or crop your images down to {pad}" # ) pad = 2**n n += 1 - if pad >= 256: - warnings.warn( + if pad >= 1024: + LOGGER.warning( "Warning : a very large dimension for automatic padding has been computed.\n" "Ensure your images are of an appropriate size and/or that you have enough memory.\n" f"The padding value is currently {pad}."
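A quick usage sketch of the remap_image helper above (illustration only, not part of the patch; it assumes a NumPy array input, though the signature also accepts a torch Tensor):

import numpy as np
from napari_cellseg3d.utils import remap_image

volume = np.array([0.0, 0.5, 1.0])
# Values are rescaled from their current [min, max] to [new_min, new_max];
# here [0, 1] is stretched to the default [0, 100] range.
remapped = remap_image(volume, new_max=100, new_min=0)
# remapped -> array([  0.,  50., 100.])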
@@ -331,14 +451,14 @@ def annotation_to_input(label_ermito): # pass -def fill_list_in_between(lst, n, elem): +def fill_list_in_between(lst, n, fill_value): """Fills a list with n * elem between each member of list. Example with list = [1,2,3], n=2, elem='&' : returns [1, &, &,2,&,&,3,&,&] Args: lst: list to fill n: number of elements to add - elem: added n times after each element of list + fill_value: added n times after each element of list Returns : Filled list @@ -347,13 +467,13 @@ def fill_list_in_between(lst, n, elem): for i in range(len(lst)): temp_list = [lst[i]] while len(temp_list) < n + 1: - temp_list.append(elem) + temp_list.append(fill_value) if i < len(lst) - 1: new_list += temp_list else: new_list.append(lst[i]) for _j in range(n): - new_list.append(elem) + new_list.append(fill_value) return new_list return None @@ -400,7 +520,7 @@ def parse_default_path(possible_paths): # ] print(default_paths) if len(default_paths) == 0: - return str(Path.home()) + return str(Path().home()) default_path = max(default_paths, key=len) return str(default_path) @@ -455,19 +575,12 @@ def load_images( raise ValueError("If loading as a folder, filetype must be specified") if as_folder: - try: - images_original = tfl_imread(filename_pattern_original) - except ValueError: - LOGGER.error( - "Loading a stack this way is no longer supported. Use napari to load a stack." - ) - - else: - images_original = tfl_imread( - filename_pattern_original - ) # tifffile imread + raise NotImplementedError( + "Loading as folder not implemented yet. Use napari to load as folder" + ) + # images_original = dask_imread(filename_pattern_original) - return images_original + return imread(filename_pattern_original) # tifffile imread # def load_predicted_masks(mito_mask_dir, er_mask_dir, filetype): @@ -527,26 +640,26 @@ def select_train_data(dataframe, ori_imgs, label_imgs, ori_filenames): return np.array(train_ori_imgs), np.array(train_label_imgs) -def format_Warning(message, category, filename, lineno, line=""): - """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` - - Args: - message: warning message - category: which type of warning has been raised - filename: file - lineno: line number - line: unused - - Returns: format - - """ - return ( - str(filename) - + ":" - + str(lineno) - + ": " - + category.__name__ - + ": " - + str(message) - + "\n" - ) +# def format_Warning(message, category, filename, lineno, line=""): +# """Formats a warning message, use in code with ``warnings.formatwarning = utils.format_Warning`` +# +# Args: +# message: warning message +# category: which type of warning has been raised +# filename: file +# lineno: line number +# line: unused +# +# Returns: format +# +# """ +# return ( +# str(filename) +# + ":" +# + str(lineno) +# + ": " +# + category.__name__ +# + ": " +# + str(message) +# + "\n" +# ) diff --git a/notebooks/assess_instance.ipynb b/notebooks/assess_instance.ipynb index b8810301..0dec4543 100644 --- a/notebooks/assess_instance.ipynb +++ b/notebooks/assess_instance.ipynb @@ -47,7 +47,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -57,6 +57,7 @@ ], "source": [ "im_path = Path(\"C:/Users/Cyril/Desktop/test/instance_test\")\n", + "# prediction_path = str(im_path / \"trailmap_ms/trailmap_pred.tif\")\n", "prediction_path = str(im_path / \"pred.tif\")\n", "gt_labels_path = str(im_path / \"labels_relabel_unique.tif\")\n", "\n", @@ -65,6 +66,7 @@ "\n", "zoom = (1 / 5, 1, 1)\n", "prediction_resized = resize(prediction, zoom)\n", 
+ "# prediction_resized = prediction # for trailmap\n", "gt_labels_resized = resize(gt_labels, zoom)\n", "\n", "\n", @@ -85,7 +87,7 @@ { "data": { "text/plain": [ - "0.5817600487210719" + "0.8592223181276479" ] }, "execution_count": 4, @@ -96,9 +98,15 @@ "source": [ "from napari_cellseg3d.utils import dice_coeff\n", "\n", + "semantic_gt = to_semantic(gt_labels_resized.copy())\n", + "semantic_pred = to_semantic(prediction_resized.copy())\n", + "\n", + "viewer.add_image(semantic_gt, colormap='bop blue')\n", + "viewer.add_image(semantic_pred, colormap='red')\n", + "\n", "dice_coeff(\n", - " to_semantic(gt_labels_resized.copy()),\n", - " to_semantic(prediction_resized.copy()),\n", + " semantic_gt,\n", + " prediction_resized\n", ")" ] }, @@ -171,7 +179,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 8, @@ -198,24 +206,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,057 - Mapping labels...\n" + "2023-03-31 15:37:19,775 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.18it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3699.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,092 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,094 - Percent of non-fused neurons found: 52.00%\n", - "2023-03-22 15:48:47,095 - Percent of fused neurons found: 36.80%\n", - "2023-03-22 15:48:47,095 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,812 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,815 - Percent of non-fused neurons found: 52.00%\n", + "2023-03-31 15:37:19,816 - Percent of fused neurons found: 36.80%\n", + "2023-03-31 15:37:19,817 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -262,24 +270,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,168 - Mapping labels...\n" + "2023-03-31 15:37:19,919 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3454.21it/s]" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3992.79it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,201 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,203 - Percent of non-fused neurons found: 54.40%\n", - "2023-03-22 15:48:47,203 - Percent of fused neurons found: 34.40%\n", - "2023-03-22 15:48:47,204 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:19,949 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:19,952 - Percent of non-fused neurons found: 54.40%\n", + "2023-03-31 15:37:19,953 - Percent of fused neurons found: 34.40%\n", + "2023-03-31 15:37:19,953 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -326,6 +334,40 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:37:21,076 - build 
program: kernel 'gaussian_blur_separable_3d' was part of a lengthy source build resulting from a binary cache miss (0.88 s)\n", + "2023-03-31 15:37:21,514 - build program: kernel 'copy_3d' was part of a lengthy source build resulting from a binary cache miss (0.42 s)\n", + "2023-03-31 15:37:22,021 - build program: kernel 'detect_maxima_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:22,642 - build program: kernel 'minimum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.59 s)\n", + "2023-03-31 15:37:23,117 - build program: kernel 'minimum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:23,651 - build program: kernel 'minimum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,188 - build program: kernel 'maximum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.52 s)\n", + "2023-03-31 15:37:24,801 - build program: kernel 'maximum_y_projection' was part of a lengthy source build resulting from a binary cache miss (0.60 s)\n", + "2023-03-31 15:37:25,263 - build program: kernel 'maximum_x_projection' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:25,766 - build program: kernel 'histogram_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:26,256 - build program: kernel 'sum_z_projection' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:26,699 - build program: kernel 'greater_constant_3d' was part of a lengthy source build resulting from a binary cache miss (0.43 s)\n", + "2023-03-31 15:37:27,158 - build program: kernel 'binary_and_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:27,635 - build program: kernel 'add_image_and_scalar_3d' was part of a lengthy source build resulting from a binary cache miss (0.47 s)\n", + "2023-03-31 15:37:28,128 - build program: kernel 'set_nonzero_pixels_to_pixelindex' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:28,580 - build program: kernel 'set_3d' was part of a lengthy source build resulting from a binary cache miss (0.45 s)\n", + "2023-03-31 15:37:29,076 - build program: kernel 'nonzero_minimum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.49 s)\n", + "2023-03-31 15:37:29,551 - build program: kernel 'set_2d' was part of a lengthy source build resulting from a binary cache miss (0.46 s)\n", + "2023-03-31 15:37:30,035 - build program: kernel 'flag_existing_labels' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:30,544 - build program: kernel 'set_column_2d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:31,033 - build program: kernel 'sum_reduction_x' was part of a lengthy source build resulting from a binary cache miss (0.48 s)\n", + "2023-03-31 15:37:31,572 - build program: kernel 'block_enumerate' was part of a lengthy source build resulting from a binary cache miss (0.53 s)\n", + "2023-03-31 15:37:32,094 - build program: kernel 'replace_intensities' was part of a lengthy source build resulting from a binary cache miss (0.51 s)\n", + "2023-03-31 15:37:32,685 - build 
program: kernel 'add_images_weighted_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:33,256 - build program: kernel 'onlyzero_overwrite_maximum_box_3d' was part of a lengthy source build resulting from a binary cache miss (0.56 s)\n", + "2023-03-31 15:37:33,845 - build program: kernel 'onlyzero_overwrite_maximum_diamond_3d' was part of a lengthy source build resulting from a binary cache miss (0.58 s)\n", + "2023-03-31 15:37:34,369 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n", + "2023-03-31 15:37:34,888 - build program: kernel 'mask_3d' was part of a lengthy source build resulting from a binary cache miss (0.50 s)\n" + ] + }, { "data": { "text/plain": [ @@ -338,7 +380,7 @@ } ], "source": [ - "voronoi = voronoi_otsu(prediction_resized, 1, outline_sigma=1)\n", + "voronoi = voronoi_otsu(prediction_resized, 0.6, outline_sigma=0.7)\n", "\n", "from skimage.morphology import remove_small_objects\n", "\n", @@ -414,24 +456,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,570 - Mapping labels...\n" + "2023-03-31 15:37:36,854 - Mapping labels...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 116/116 [00:00<00:00, 3527.67it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 611.96it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2023-03-22 15:48:47,607 - Calculating the number of neurons not found...\n", - "2023-03-22 15:48:47,609 - Percent of non-fused neurons found: 79.20%\n", - "2023-03-22 15:48:47,609 - Percent of fused neurons found: 9.60%\n", - "2023-03-22 15:48:47,610 - Overall percent of neurons found: 88.80%\n" + "2023-03-31 15:37:37,087 - Calculating the number of neurons not found...\n", + "2023-03-31 15:37:37,098 - Percent of non-fused neurons found: 87.20%\n", + "2023-03-31 15:37:37,104 - Percent of fused neurons found: 1.60%\n", + "2023-03-31 15:37:37,114 - Overall percent of neurons found: 88.80%\n" ] }, { @@ -444,15 +486,15 @@ { "data": { "text/plain": [ - "(99,\n", - " 12,\n", + "(109,\n", + " 2,\n", " 13,\n", - " 17,\n", - " 0.6286692001809993,\n", - " 0.9378875115172982,\n", - " 0.949109422876503,\n", - " 0.5827007113964422,\n", - " 0.7306099091287442)" + " 8,\n", + " 0.8285521200005869,\n", + " 0.8809251900364068,\n", + " 0.9838709677419355,\n", + " 0.782258064516129,\n", + " 1.0)" ] }, "execution_count": 15, @@ -473,10 +515,25 @@ "outputs_hidden": false } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-31 15:40:34,683 - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'\n" + ] + } + ], "source": [ "# eval.evaluate_model_performance(gt_labels_resized, voronoi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -495,7 +552,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/notebooks/full_plot.ipynb b/notebooks/full_plot.ipynb index 5c640e1b..87f973f9 100644 --- a/notebooks/full_plot.ipynb +++ b/notebooks/full_plot.ipynb @@ -10,7 
+10,6 @@ "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", - "from PIL import Image\n", "from tifffile import imread" ] }, diff --git a/notebooks/train_wnet.ipynb b/notebooks/train_wnet.ipynb new file mode 100644 index 00000000..4fb6c0f4 --- /dev/null +++ b/notebooks/train_wnet.ipynb @@ -0,0 +1,267 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:14.017741900Z", + "start_time": "2023-07-10T08:00:14.007742500Z" + } + }, + "outputs": [], + "source": [ + "from napari_cellseg3d.code_models.models.wnet.train_wnet import Config, train\n", + "from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "config = Config()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:14.382675700Z", + "start_time": "2023-07-10T08:00:14.354604Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Basic config :" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "config.num_epochs = 100\n", + "config.val_interval = 1 # performs validation with test dataset every n epochs\n", + "config.batch_size = 1" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:15.040773600Z", + "start_time": "2023-07-10T08:00:15.020804400Z" + } + } + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Image directories :\n", + "- `train_volume_directory` : The path to the folder containing the 3D .tif files on which to train\n", + "- `eval_volume_directory` : If available, the path to the validation set to compute Dice metric on; labels should be in a \"lab\" folder, volumes in \"vol\" at the specified path. 
Images and labels should match when sorted alphabetically" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "config.train_volume_directory = str(Path.home() / \"Desktop/Code/WNet-benchmark/dataset/VIP_small\")\n", + "config.eval_volume_directory = None\n", + "\n", + "config.save_model_path = \"./results\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:15.810624400Z", + "start_time": "2023-07-10T08:00:15.791682400Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Advanced config\n", + "Note : more parameters can be found in the config.py file, depending on your needs" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "config.in_channels = 1\n", + "config.out_channels = 1\n", + "config.num_classes = 2\n", + "config.dropout = 0.65\n", + "\n", + "config.lr = 1e-6 # learning rate\n", + "config.scheduler = \"None\" # \"CosineAnnealingLR\" # \"ReduceLROnPlateau\" # can be further tweaked in config\n", + "config.weight_decay = 0.01 # None" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-07-10T08:00:16.455904800Z", + "start_time": "2023-07-10T08:00:16.445901900Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Config:\n", + "('in_channels', 1)\n", + "('out_channels', 1)\n", + "('num_classes', 2)\n", + "('dropout', 0.65)\n", + "('use_clipping', False)\n", + "('clipping', 1)\n", + "('lr', 1e-06)\n", + "('scheduler', 'None')\n", + "('weight_decay', 0.01)\n", + "('intensity_sigma', 1)\n", + "('spatial_sigma', 4)\n", + "('radius', 2)\n", + "('n_cuts_weight', 0.5)\n", + "('reconstruction_loss', 'MSE')\n", + "('rec_loss_weight', 0.005)\n", + "('num_epochs', 100)\n", + "('val_interval', 1)\n", + "('batch_size', 1)\n", + "('num_workers', 4)\n", + "('sa', 50)\n", + "('sb', 20)\n", + "('sg', 1)\n", + "('w1', 50)\n", + "('w2', 20)\n", + "('n_iter', 5)\n", + "('train_volume_directory', 'C:\\\\Users\\\\Cyril\\\\Desktop\\\\Code\\\\WNet-benchmark\\\\dataset\\\\VIP_small')\n", + "('eval_volume_directory', None)\n", + "('normalize_input', True)\n", + "('normalizing_function', )\n", + "('use_patch', False)\n", + "('patch_size', (64, 64, 64))\n", + "('num_patches', 30)\n", + "('eval_num_patches', 20)\n", + "('do_augmentation', True)\n", + "('parallel', False)\n", + "('save_model', True)\n", + "('save_model_path', './../results/new_model/wnet_new_model_all_data_3class.pth')\n", + "('save_every', 5)\n", + "('weights_path', None)\n", + "Initializing training...\n", + "Getting the data\n", + "2023-07-10 10:00:17,137 - Images :\n", + "2023-07-10 10:00:17,137 - 1\n", + "2023-07-10 10:00:17,137 - 2\n", + "2023-07-10 10:00:17,137 - **********\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████| 2/2 [00:00=3.8" dependencies = [ "numpy", - "napari>=0.4.14", + "napari[all]>=0.4.14", "QtPy", "opencv-python>=4.5.5", +# "dask-image>=0.6.0", "scikit-image>=0.19.2", "matplotlib>=3.4.1", "tifffile>=2022.2.9", "imageio-ffmpeg>=0.4.5", + "imagecodecs>=2023.3.16", "torch>=1.11", "monai[nibabel,einops]>=0.9.0", + "itk", "tqdm", "nibabel", "scikit-image", @@ -92,15 +95,15 @@ profile = "black" line_length = 79 [project.optional-dependencies] -all = [ - "napari[all]>=0.4.14", +crf = [ + 
"pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", ] dev = [ "isort", "black", "ruff", - "tuna", "pre-commit", + "tuna", ] docs = [ "sphinx", @@ -111,7 +114,17 @@ docs = [ test = [ "pytest", "pytest_qt", + "pytest-cov", "coverage", "tox", "twine", + "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master", +] +onnx-cpu = [ + "onnx", + "onnxruntime" +] +onnx-gpu = [ + "onnx", + "onnxruntime-gpu" ] diff --git a/requirements.txt b/requirements.txt index 3189e9c4..ada03ae4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ black coverage +imageio-ffmpeg>=0.4.5 isort itk pytest @@ -15,13 +16,11 @@ QtPy opencv-python>=4.5.5 pre-commit pyclesperanto-prototype>=0.22.0 -pysqlite3 -dask-image>=0.6.0 matplotlib>=3.4.1 +ruff tifffile>=2022.2.9 -imageio-ffmpeg>=0.4.5 torch>=1.11 -monai[nibabel,einops]>=1.0.1 +monai[nibabel,einops,tifffile]>=1.0.1 pillow scikit-image>=0.19.2 vispy>=0.9.6 diff --git a/setup.cfg b/setup.cfg index 3a0bdaae..7a72482a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,35 @@ +[metadata] +name = napari-cellseg3d +version = 0.0.3rc1 +author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis +author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu + +license = MIT +description = plugin for cell segmentation +long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Development Status :: 2 - Pre-Alpha + Intended Audience :: Science/Research + Framework :: napari + Topic :: Software Development :: Testing + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Operating System :: OS Independent + License :: OSI Approved :: MIT License + Topic :: Scientific/Engineering :: Artificial Intelligence + Topic :: Scientific/Engineering :: Image Processing + Topic :: Scientific/Engineering :: Visualization + +url = https://github.com/AdaptiveMotorControlLab/CellSeg3d +project_urls = + Bug Tracker = https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues + Documentation = https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html + Source Code = https://github.com/AdaptiveMotorControlLab/CellSeg3d + [options] packages = find: include_package_data = True @@ -5,9 +37,37 @@ python_requires = >=3.8 package_dir = =. +# add your package requirements here +install_requires = + numpy + napari[all]>=0.4.14 + QtPy + opencv-python>=4.5.5 + scikit-image>=0.19.2 + matplotlib>=3.4.1 + tifffile>=2022.2.9 + imageio-ffmpeg>=0.4.5 + torch>=1.11 + monai[nibabel,einops,tifffile]>=1.0.1 + itk + tqdm + nibabel + pyclesperanto-prototype + scikit-image + pillow + tqdm + matplotlib + vispy>=0.9.6 + [options.packages.find] where = . 
+[options.package_data]
+napari_cellseg3d =
+    res/*.png
+    code_models/models/pretrained/*.json
+    napari.yaml
+
 [options.entry_points]
 napari.manifest =
     napari-cellseg3d = napari_cellseg3d:napari.yaml
diff --git a/tox.ini b/tox.ini
index 87338cd8..1b9b5e22 100644
--- a/tox.ini
+++ b/tox.ini
@@ -29,13 +29,14 @@ passenv =
 deps =
     pytest  # https://docs.pytest.org/en/latest/contents.html
     pytest-cov  # https://pytest-cov.readthedocs.io/en/latest/
-;   dask-image
-;   # you can remove these if you don't use them
     napari
     PyQt5
     magicgui
     pytest-qt
     qtpy
+    git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf
 ;   pyopencl[pocl]
-
+;   opencv-python
+extras = crf
+usedevelop = true
 commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml
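
Usage note: the new notebooks/train_wnet.ipynb boils down to the short script below. This is a sketch only, not part of the patch: it assumes train(config) is the training entry point (the notebook imports both Config and train from train_wnet, and the logged output above comes from such a call), and the dataset path is illustrative.

from pathlib import Path

from napari_cellseg3d.code_models.models.wnet.train_wnet import Config, train

# Build a default config and override the basics (fields as shown in the notebook)
config = Config()
config.num_epochs = 100
config.val_interval = 1  # validate every epoch
config.batch_size = 1

# Data: a folder of 3D .tif volumes; eval set is optional ("vol"/"lab" subfolders)
config.train_volume_directory = str(Path.home() / "datasets/wnet_volumes")  # illustrative path
config.eval_volume_directory = None
config.save_model_path = "./results"

# Advanced knobs (see config.py for the full list)
config.num_classes = 2
config.dropout = 0.65
config.lr = 1e-6
config.scheduler = "None"  # or "CosineAnnealingLR" / "ReduceLROnPlateau"
config.weight_decay = 0.01

train(config)  # assumed signature, matching the notebook's imports and logs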